From 339700ea005a3769158c882e063484f04f65b8b4 Mon Sep 17 00:00:00 2001 From: Andy Maloney Date: Mon, 24 Oct 2022 16:42:31 -0700 Subject: [PATCH] Trace tool (#1723) Summary: ### Motivation Continued work on the diagnostics tool, this includes a model trace tool. ### Changes proposed - Changes include new JavaScript and Python files for the Bokeh Application. - Updates to helper modules in stats for the JavaScript Pull Request resolved: https://github.com/facebookresearch/beanmachine/pull/1723 Test Plan: The tool was run in the Coin flipping tutorial. ### Types of changes - [ ] Docs change / refactoring / dependency upgrade - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ### Checklist - [x] My code follows the code style of this project. - [ ] My change requires a change to the documentation. - [ ] I have updated the documentation accordingly. - [x] I have read the **[CONTRIBUTING](https://github.com/facebookresearch/beanmachine/blob/main/CONTRIBUTING.md)** document. - [ ] I have added tests to cover my changes. - [ ] All new and existing tests passed. - [x] The title of my pull request is a short description of the requested changes. Reviewed By: feynmanliang Differential Revision: D39978113 Pulled By: horizon-blue fbshipit-source-id: a319ed8d20ffc45ca678203050768612294694a4 --- .../ppl/diagnostics/tools/__init__.py | 8 +- .../ppl/diagnostics/tools/js/package.json | 25 - .../diagnostics/tools/js/src/stats/array.ts | 89 +++ .../tools/js/src/stats/dataTransformation.ts | 48 ++ .../tools/js/src/stats/histogram.ts | 152 ++++ .../tools/js/src/trace/callbacks.ts | 220 ++++++ .../diagnostics/tools/js/src/trace/index.ts | 12 + .../tools/js/src/trace/interfaces.ts | 110 +++ .../ppl/diagnostics/tools/js/tsconfig.json | 1 - .../diagnostics/tools/js/webpack.config.js | 15 +- .../ppl/diagnostics/tools/js/yarn.lock | 141 +--- .../ppl/diagnostics/tools/marginal1d/tool.py | 10 +- .../diagnostics/tools/marginal1d/typing.py | 1 - .../ppl/diagnostics/tools/marginal1d/utils.py | 1 - .../ppl/diagnostics/tools/trace/__init__.py | 4 + .../ppl/diagnostics/tools/trace/tool.py | 161 ++++ .../ppl/diagnostics/tools/trace/typing.py | 157 ++++ .../ppl/diagnostics/tools/trace/utils.py | 731 ++++++++++++++++++ .../ppl/diagnostics/tools/utils/accessor.py | 19 +- .../tools/utils/diagnostic_tool_base.py | 18 +- .../tools/utils/model_serializers.py | 39 +- .../diagnostics/tools/utils/plotting_utils.py | 1 - src/beanmachine/ppl/diagnostics/tools/viz.py | 19 +- 23 files changed, 1775 insertions(+), 207 deletions(-) create mode 100644 src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts create mode 100644 src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts create mode 100644 src/beanmachine/ppl/diagnostics/tools/js/src/trace/index.ts create mode 100644 src/beanmachine/ppl/diagnostics/tools/js/src/trace/interfaces.ts create mode 100644 src/beanmachine/ppl/diagnostics/tools/trace/__init__.py create mode 100644 src/beanmachine/ppl/diagnostics/tools/trace/tool.py create mode 100644 src/beanmachine/ppl/diagnostics/tools/trace/typing.py create mode 100644 src/beanmachine/ppl/diagnostics/tools/trace/utils.py diff --git a/src/beanmachine/ppl/diagnostics/tools/__init__.py b/src/beanmachine/ppl/diagnostics/tools/__init__.py index 3e58d695f5..43cad634da 100644 --- a/src/beanmachine/ppl/diagnostics/tools/__init__.py +++ b/src/beanmachine/ppl/diagnostics/tools/__init__.py @@ -6,15 +6,19 @@ # flake8: noqa """Visual diagnostic tools for Bean Machine models.""" - import sys from pathlib import Path if sys.version_info >= (3, 8): + # NOTE: We need to import NotRequired from typing_extensions until PEP 655 is + # accepted, see https://peps.python.org/pep-0655/. This is to follow the + # interface objects in JavaScript that allow keys to not be required using ?. from typing import TypedDict + + from typing_extensions import NotRequired else: - from typing_extensions import TypedDict + from typing_extensions import NotRequired, TypedDict TOOLS_DIR = Path(__file__).parent.resolve() diff --git a/src/beanmachine/ppl/diagnostics/tools/js/package.json b/src/beanmachine/ppl/diagnostics/tools/js/package.json index 38e7770cb5..7518ebe720 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/package.json +++ b/src/beanmachine/ppl/diagnostics/tools/js/package.json @@ -13,8 +13,6 @@ "fast-kde": "^0.2.1" }, "devDependencies": { - "@types/node": "^18.0.4", - "@typescript-eslint/eslint-plugin": "^5.30.5", "@typescript-eslint/parser": "^5.30.5", "eslint": "^8.19.0", "eslint-config-airbnb": "^19.0.4", @@ -24,32 +22,9 @@ "eslint-plugin-prefer-arrow": "^1.2.3", "eslint-plugin-react": "^7.28.0", "eslint-plugin-react-hooks": "^4.3.0", - "prettier": "^2.7.1", "ts-loader": "^9.3.1", - "ts-node": "^10.9.1", "typescript": "^4.7.4", "webpack": "^5.74.0", "webpack-cli": "^4.10.0" - }, - "overrides": { - "cwise": "$cwise", - "minimist": "$minimist", - "quote-stream": "$quote-stream", - "static-eval": "$static-eval", - "static-module": "$static-module", - "typedarray-pool": "$typedarray-pool" - }, - "peerDependencies": { - "@types/cwise": "^1.0.4", - "@types/minimist": "^1.2.2", - "@types/static-eval": "^0.2.31", - "@types/typedarray-pool": "^1.1.2", - "buffer": "^6.0.3", - "cwise": "^1.0.10", - "minimist": "^1.2.6", - "quote-stream": "^1.0.2", - "static-eval": "2.1.0", - "static-module": "^3.0.4", - "typedarray-pool": "^1.2.0" } } diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts index d67b2e7dbf..39dcf65bd0 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts @@ -39,3 +39,92 @@ export const numericalSort = (data: number[]): number[] => { return a < b ? -1 : a > b ? 1 : 0; }); }; + +/** + * Determine the shape of the given array. + * + * @param {any[]} data - Any array of data. + * @returns {number[]} The shape of the data as an array. + */ +export const shape = (data: any[]): number[] => { + // From https://stackoverflow.com/questions/10237615/get-size-of-dimensions-in-array + const computeShape = (array: any[]): any[] => { + return array.length ? [...[array.length], ...computeShape(array[0])] : []; + }; + const arrayShape = computeShape(data); + // Remove the empty array that will exist at the end of the shape array, since it is + // the returned "else" value from above. + const dataShape = []; + for (let i = 0; i < arrayShape.length; i += 1) { + if (!Array.isArray(arrayShape[i])) { + dataShape.push(arrayShape[i]); + } + } + return dataShape; +}; + +/** + * Create an array that starts and stops with the given number of steps. + * + * @param {number} start - Where to start the array from. + * @param {number} stop - Where to stop the array. + * @param {number} [step] - The step size to take. + * @param {boolean} [closed] - Flag used to return a closed array or not. + * @param {null | number} [size] - If not null, then will return an array with the given + * size. + * @returns {number[]} An array that is linearly spaced between the start and stop + * values. + */ +export const linearRange = ( + start: number, + stop: number, + step: number = 1, + closed: boolean = true, + size: null | number = null, +): number[] => { + if (size !== null) { + step = (stop - start) / size; + } + let len = (stop - start) / step + 1; + if (!closed) { + len = (stop - start - step) / step + 1; + } + return Array.from({length: len}, (_, i) => { + return start + i * step; + }); +}; + +/** + * Return the indices that would sort the array. Follows NumPy's implementation. + * + * @param {number[]} data - The data to sort. + * @returns {number[]} An array of indices that would sort the original array. + */ +export const argSort = (data: number[]): number[] => { + const dataCopy = data.slice(0); + return dataCopy + .map((value, index) => { + return [value, index]; + }) + .sort((a, b) => { + return a[0] - b[0]; + }) + .map((value) => { + return value[1]; + }); +}; + +/** + * Count the number of time a value appears in an array. + * + * @param {number[]} data - The numeric array to count objects for. + * @returns {{[key: string]: number}} An object that contains the keys as the items in + * the original array, and values that are counts of the key. + */ +export const valueCounts = (data: number[]): {[key: string]: number} => { + const counts: {[key: string]: number} = {}; + for (let i = 0; i < data.length; i += 1) { + counts[data[i]] = (counts[data[i]] || 0) + 1; + } + return counts; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/dataTransformation.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/dataTransformation.ts index a8ecb46596..2c482a1b41 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/dataTransformation.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/dataTransformation.ts @@ -5,6 +5,8 @@ * LICENSE file in the root directory of this source tree. */ +import {argSort, valueCounts} from './array'; + /** * Scale the given array of numbers by the given scaleFactor. Note that this method * divides values in the given array by the scaleFactor. @@ -32,3 +34,49 @@ export const scaleToOne = (data: number[]): number[] => { const scaleFactor = Math.max(...data); return scaleBy(data, scaleFactor); }; + +/** + * Assign ranks to the given data. Follows SciPy's and ArviZ's implementations. + * + * @param {number[]} data - The numeric data to rank. + * @returns {number[]} An array of rankings. + */ +export const rankData = (data: number[]): number[] => { + const n = data.length; + const rank = Array(n); + const sortedIndex = argSort(data); + for (let i = 0; i < rank.length; i += 1) { + rank[sortedIndex[i]] = i + 1; + } + const counts = valueCounts(data); + const countsArray = Object.entries(counts); + const keys = []; + const keyCounts = []; + for (let i = 0; i < countsArray.length; i += 1) { + const [key, count] = countsArray[i]; + if (count > 1) { + keys.push(parseFloat(key)); + keyCounts.push(count); + } + } + for (let i = 0; i < keys.length; i += 1) { + const repeatIndices = []; + for (let j = 0; j < data.length; j += 1) { + if (data[j] === keys[i]) { + repeatIndices.push(j); + } + } + const rankValues = []; + for (let k = 0; k < repeatIndices.length; k += 1) { + rankValues.push(rank[repeatIndices[k]]); + } + const sum = rankValues.reduce((previousValue, currentValue) => { + return previousValue + currentValue; + }, 0.0); + const rankMean = sum / rankValues.length; + for (let k = 0; k < repeatIndices.length; k += 1) { + rank[repeatIndices[k]] = rankMean; + } + } + return rank; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts new file mode 100644 index 0000000000..40a4868571 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts @@ -0,0 +1,152 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {linearRange, numericalSort, shape} from './array'; +import {rankData, scaleToOne} from './dataTransformation'; +import {mean as computeMean} from './pointStatistic'; + +/** + * Compute the histogram of the given data. + * + * @param {number[]} data - Data to bin. + * @param {number} [numBins] - The number of bins to use for the histogram. If none is + * given, then we follow ArviZ's implementation by using twice then number of bins + * of the Sturges formula. + * @returns {number[][]} [TODO:description] + */ +export const calculateHistogram = (data: number[], numBins: number = 0): number[][] => { + const sortedData = numericalSort(data); + const numSamples = sortedData.length; + const dataMin = Math.min(...data); + const dataMax = Math.max(...data); + if (numBins === 0) { + numBins = Math.floor(Math.ceil(2 * Math.log2(numSamples)) + 1); + } + const binSize = + (dataMax - dataMin) / numBins === 0 ? 1 : (dataMax - dataMin) / numBins; + const bins = Array(numBins) + .fill([0, 0]) + .map((_, i) => { + return [i, 0]; + }); + + for (let i = 0; i < data.length; i += 1) { + const datum = sortedData[i]; + let binIndex = Math.floor((datum - dataMin) / binSize); + // Subtract 1 if the value lies on the last bin. + if (binIndex === numBins) { + binIndex -= 1; + } + bins[binIndex][1] += 1; + } + return bins; +}; + +export interface RankHistogram { + [key: string]: { + quad: { + left: number[]; + top: number[]; + right: number[]; + bottom: number[]; + chain: number[]; + draws: string[]; + rank: number[]; + }; + line: {x: number[]; y: number[]}; + chain: number[]; + rankMean: number[]; + mean: number[]; + }; +} + +/** + * A histogram of rank data. + * + * @param {number[][]} data - Raw random variable data for several chains. + * @returns {RankHistogram} A histogram of the data rankings. + */ +export const rankHistogram = (data: number[][]): RankHistogram => { + const [numChains, numDraws] = shape(data); + const numSamples = numChains * numDraws; + const flatData = data.flat(); + + // Calculate the rank of the data and ensure it is the same shape as the original + // data. + const rank = rankData(flatData); + const rankArray = []; + let start = Number.NaN; + let end = Number.NaN; + for (let i = 0; i < numChains; i += 1) { + if (i === 0) { + start = 0; + end = numDraws; + } else { + start = end; + end = (i + 1) * numDraws; + } + const chainRanks = rank.slice(start, end); + rankArray.push(chainRanks); + start = end; + end = (i + 1) * numDraws; + } + + // Calculate the number of bins needed. We will follow ArviZ and use twice the result + // using the Sturges' formula. + const numBins = Math.floor(Math.ceil(2 * Math.log2(numSamples)) + 1); + const lastBinEdge = Math.max(...rank); + + // Calculate the bin edges. Since the linearRange function computes a linear spacing + // of values between the start and end point, we need to ensure they are integer + // values. + let binEdges = linearRange(0, lastBinEdge, 1, true, numBins); + binEdges = binEdges.map((value) => { + return Math.ceil(value); + }); + + // Calculate the histograms of the rank data, and normalize it for each chain. + const output = {} as RankHistogram; + for (let i = 0; i < numChains; i += 1) { + const chainIndex = i + 1; + const chainName = `chain${chainIndex}`; + const chainRankHistogram = calculateHistogram(rankArray[i], numBins); + let counts = []; + for (let j = 0; j < chainRankHistogram.length; j += 1) { + counts.push(chainRankHistogram[j][1]); + } + counts = scaleToOne(counts); + const chainCounts = counts.map((value) => { + return value + i; + }); + + const chainRankMean = computeMean(chainCounts); + const left = binEdges.slice(0, binEdges.length - 1); + const right = binEdges.slice(1); + const binLabel = []; + for (let j = 0; j < left.length; j += 1) { + binLabel.push(`${left[j].toLocaleString()}-${right[j].toLocaleString()}`); + } + const x = linearRange(0, numSamples, 1); + const y = Array(x.length).fill(chainRankMean); + output[chainName] = { + quad: { + left: left, + top: chainCounts, + right: right, + bottom: Array(numBins).fill(i), + chain: Array(left.length).fill(i + 1), + draws: binLabel, + rank: counts, + }, + line: {x: x, y: y}, + chain: Array(x.length).fill(i + 1), + rankMean: Array(x.length).fill(chainIndex - chainRankMean), + mean: Array(x.length).fill(computeMean(counts)), + }; + } + return output; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts new file mode 100644 index 0000000000..1c4d4ab362 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts @@ -0,0 +1,220 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {Axis} from '@bokehjs/models/axes/axis'; +import * as interfaces from './interfaces'; +import {linearRange, shape} from '../stats/array'; +import {interval as hdiInterval} from '../stats/highestDensityInterval'; +import {rankHistogram} from '../stats/histogram'; +import {oneD} from '../stats/marginal'; +import {mean} from '../stats/pointStatistic'; + +const figureNames = ['marginals', 'forests', 'traces', 'ranks']; + +/** + * Update the given Bokeh Axis object with the new label string. You must use this + * method to update axis strings using TypeScript, otherwise the ts compiler will throw + * a type check error. + * + * @param {Axis} axis - The Bokeh Axis object needing a new label. + * @param {string | null} label - The new label for the Bokeh Axis object. + */ +export const updateAxisLabel = (axis: Axis, label: string | null): void => { + // Type check requirement. + if ('axis_label' in axis) { + axis.axis_label = label; + } +}; + +/** + * Compute data for the trace diagnostic tool. + * + * @param {number[][]} data - Raw random variable data from the model for all chains. + * @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when + * calculating the Kernel Density Estimate (KDE). + * @param {number} hdiProbability - The highest density interval probability to use when + * calculating the HDI. + * @returns {interfaces.Data} Data object that contains data for each figure including + * each chain. + */ +export const computeData = ( + data: number[][], + bwFactor: number, + hdiProbability: number, +): interfaces.Data => { + const [numChains, numDraws] = shape(data); + const output = {} as interfaces.Data; + for (let i = 0; i < figureNames.length; i += 1) { + const figureName = figureNames[i]; + if (figureName !== 'ranks') { + switch (figureName) { + case 'marginals': + output[figureName] = {} as interfaces.MarginalDataAllChains; + break; + case 'forests': + output[figureName] = {} as interfaces.ForestDataAllChains; + break; + case 'traces': + output[figureName] = {} as interfaces.TraceDataAllChains; + break; + default: + break; + } + for (let j = 0; j < numChains; j += 1) { + const chainIndex = j + 1; + const chainName = `chain${chainIndex}`; + const chainData = data[j]; + const marginal = oneD(chainData, bwFactor); + const marginalMean = mean(marginal.x); + let hdiBounds; + switch (figureName) { + case 'marginals': + output[figureName][chainName] = {} as interfaces.MarginalDataSingleChain; + output[figureName][chainName] = { + line: {x: marginal.x, y: marginal.y}, + chain: chainIndex, + mean: marginalMean, + bandwidth: marginal.bandwidth, + }; + break; + case 'forests': + output[figureName][chainName] = {} as interfaces.ForestDataSingleChain; + hdiBounds = hdiInterval(chainData, hdiProbability); + output[figureName][chainName] = { + line: { + x: [hdiBounds.lowerBound, hdiBounds.upperBound], + y: Array(2).fill(chainIndex), + }, + circle: {x: [marginalMean], y: [chainIndex]}, + chain: chainIndex, + mean: marginalMean, + }; + break; + case 'traces': + output[figureName][chainName] = {} as interfaces.TraceDataSingleChain; + output[figureName][chainName] = { + line: {x: linearRange(0, numDraws - 1, 1), y: chainData}, + chain: chainIndex, + mean: marginalMean, + }; + break; + default: + break; + } + } + } else if (figureName === 'ranks') { + output[figureName] = rankHistogram(data); + } + } + return output; +}; + +/** + * Callback used to update the Bokeh application in the notebook. + * + * @param {number[][]} data - Raw random variable data from the model for all chains. + * @param {string} rvName - The name of the random variable from the model. + * @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when + * calculating the kernel density estimate. + * @param {number} hdiProbability - The highest density interval probability to use when + * calculating the HDI. + * @param {interfaces.Sources} sources - Bokeh sources used to render glyphs in the + * application. + * @param {interfaces.Figures} figures - Bokeh figures shown in the application. + * @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs. + */ +export const update = ( + data: number[][], + rvName: string, + bwFactor: number, + hdiProbability: number, + sources: interfaces.Sources, + figures: interfaces.Figures, + tooltips: interfaces.Tooltips, +): void => { + const [numChains] = shape(data); + const computedData = computeData(data, bwFactor, hdiProbability); + for (let i = 0; i < figureNames.length; i += 1) { + const figureName = figureNames[i]; + const figure = figures[figureName]; + for (let j = 0; j < numChains; j += 1) { + const chainIndex = j + 1; + const chainName = `chain${chainIndex}`; + const chainData = computedData[figureName][chainName]; + const source = sources[figureName][chainName]; + switch (figureName) { + case 'marginals': + source.line.data = { + x: chainData.line.x, + y: chainData.line.y, + chain: Array(chainData.line.x.length).fill(chainData.chain), + mean: Array(chainData.line.x.length).fill(chainData.mean), + }; + updateAxisLabel(figure.below[0], rvName); + tooltips[figureName][j].tooltips = [ + ['Chain', '@chain'], + ['Mean', '@mean'], + [rvName, '@x'], + ]; + break; + case 'forests': + source.line.data = { + x: chainData.line.x, + y: chainData.line.y, + chain: Array(chainData.line.x.length).fill(chainData.chain), + mean: Array(chainData.line.x.length).fill(chainData.mean), + }; + source.circle.data = { + x: chainData.circle.x, + y: chainData.circle.y, + chain: [chainData.chain], + mean: [chainData.mean], + }; + updateAxisLabel(figure.below[0], rvName); + tooltips[figureName][j].tooltips = [ + ['Chain', '@chain'], + [rvName, '@mean'], + ]; + break; + case 'traces': + source.line.data = { + x: chainData.line.x, + y: chainData.line.y, + chain: Array(chainData.line.x.length).fill(chainData.chain), + mean: Array(chainData.line.x.length).fill(chainData.mean), + }; + updateAxisLabel(figure.left[0], rvName); + tooltips[figureName][j].tooltips = [ + ['Chain', '@chain'], + ['Mean', '@mean'], + [rvName, '@y'], + ]; + break; + case 'ranks': + source.line.data = { + x: chainData.line.x, + y: chainData.line.y, + chain: chainData.chain, + rankMean: chainData.rankMean, + }; + tooltips[figureName][j].line.tooltips = [ + ['Chain', '@chain'], + ['Rank mean', '@rankMean'], + ]; + source.quad.data = chainData.quad; + tooltips[figureName][j].quad.tooltips = [ + ['Chain', '@chain'], + ['Draws', '@draws'], + ['Rank', '@rank'], + ]; + break; + default: + break; + } + } + } +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/trace/index.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/index.ts new file mode 100644 index 0000000000..8b75182970 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/index.ts @@ -0,0 +1,12 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import * as trace from './callbacks'; + +// The CustomJS methods used by Bokeh require us to make the JavaScript available in the +// browser, which is done by defining it below. +(window as any).trace = trace; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/trace/interfaces.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/interfaces.ts new file mode 100644 index 0000000000..b7e1961889 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/interfaces.ts @@ -0,0 +1,110 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {ColumnDataSource} from '@bokehjs/models/sources/column_data_source'; +import {HoverTool} from '@bokehjs/models/tools/inspectors/hover_tool'; +import {Plot} from '@bokehjs/models/plots/plot'; +import {RankHistogram} from '../stats/histogram'; + +// NOTE: In the corresponding Python typing files for the diagnostic tool, we define +// similar types using a TypedDict object. TypeScript allows us to maintain +// semantic information about the key names and their types in the same object and +// still have the ability to loop over objects as long as we have the +// [key: string]: any; indicator in the interface definition. This boils down to a +// Python type of Dict[Any, Any], which again loses all type information about the +// object we are defining. We are mirroring what is done in Python here, so we +// keep the semantic information here at the expense of losing type information +// similarly to what is done in Python. + +export interface LineOrCircleGlyphData { + [key: string]: any; + x: number[]; + y: number[]; +} + +export interface MarginalDataSingleChain { + [key: string]: any; + line: LineOrCircleGlyphData; + chain: number; + mean: number; + bandwidth: number; +} + +export interface ForestDataSingleChain { + [key: string]: any; + line: LineOrCircleGlyphData; + circle: LineOrCircleGlyphData; + chain: number; + mean: number; +} + +export interface TraceDataSingleChain { + [key: string]: any; + line: LineOrCircleGlyphData; + chain: number; + mean: number; +} + +export interface MarginalDataAllChains { + [key: string]: MarginalDataSingleChain; +} + +export interface ForestDataAllChains { + [key: string]: ForestDataSingleChain; +} + +export interface TraceDataAllChains { + [key: string]: TraceDataSingleChain; +} + +export interface Data { + [key: string]: any; + marginals: MarginalDataAllChains; + forests: ForestDataAllChains; + traces: TraceDataAllChains; + ranks: RankHistogram; +} + +export interface SourceSingleChain { + line: ColumnDataSource; + circle?: ColumnDataSource; + quad?: ColumnDataSource; +} + +export interface SourceAllChains { + [key: string]: SourceSingleChain; +} + +export interface Sources { + [key: string]: any; + marginals: SourceAllChains; + forests: SourceAllChains; + traces: SourceAllChains; + ranks: SourceAllChains; +} + +export interface Figures { + [key: string]: any; + marginals: Plot; + forests: Plot; + traces: Plot; + ranks: Plot; +} + +export interface RankTooltips { + [key: string]: HoverTool; + line: HoverTool; + quad: HoverTool; +} + +export interface Tooltips { + [key: string]: any; + marginals: HoverTool[]; + forests: HoverTool[]; + traces: HoverTool[]; + ranks: RankTooltips[]; +} diff --git a/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json b/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json index 37de8788f8..2212d273f3 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json +++ b/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json @@ -23,7 +23,6 @@ "node_modules/@bokeh/bokehjs/build/js/lib/*", "node_modules/@bokeh/bokehjs/build/js/types/*" ], - "compute-histogram/*": ["node_modules/compute-histogram/*"], "fast-kde/*": ["node_modules/fast-kde/*"], "ndarray/*": ["node_modules/ndarray/*"], "ndarray-fft/*": [ diff --git a/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js b/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js index cbe79da94d..e13c2507f6 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js +++ b/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js @@ -10,24 +10,35 @@ const path = require('path'); module.exports = { entry: { marginal1d: './src/marginal1d/index.ts', + trace: './src/trace/index.ts', }, output: { filename: '[name].js', path: path.resolve(__dirname, 'dist'), }, module: { - rules: [{test: /\.ts$/, use: 'ts-loader', exclude: /node_modules/}], + rules: [ + { + test: /\.ts$/, + use: 'ts-loader', + exclude: /node_modules\/(?!(@bokeh\/bokehjs\/build\/js\/lib)\/).*/, + }, + ], }, target: 'web', mode: 'production', resolve: { - extensions: ['.ts'], + extensions: ['.ts', '.js'], modules: ['./stats', './interfaces', './types', './node_modules'], alias: { 'fast-kde/src/density1d': path.resolve( __dirname, 'node_modules/fast-kde/src/density1d.js', ), + '@bokehjs/models/ranges/range1d': path.resolve( + __dirname, + 'node_modules/@bokeh/bokehjs/build/js/lib/models/ranges/range1d.js', + ), }, }, optimization: { diff --git a/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock b/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock index 5ece43a2c8..7f512032f9 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock +++ b/src/beanmachine/ppl/diagnostics/tools/js/yarn.lock @@ -58,13 +58,6 @@ jquery-ui ">=1.8.0" tslib "^1.10.0" -"@cspotcode/source-map-support@^0.8.0": - version "0.8.1" - resolved "https://registry.yarnpkg.com/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz#00629c35a688e05a88b1cda684fb9d5e73f000a1" - integrity sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw== - dependencies: - "@jridgewell/trace-mapping" "0.3.9" - "@discoveryjs/json-ext@^0.5.0": version "0.5.7" resolved "https://registry.yarnpkg.com/@discoveryjs/json-ext/-/json-ext-0.5.7.tgz#1d572bfbbe14b7704e0ba0f39b74815b84870d70" @@ -113,7 +106,7 @@ "@jridgewell/sourcemap-codec" "^1.4.10" "@jridgewell/trace-mapping" "^0.3.9" -"@jridgewell/resolve-uri@3.1.0", "@jridgewell/resolve-uri@^3.0.3": +"@jridgewell/resolve-uri@3.1.0": version "3.1.0" resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78" integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w== @@ -136,14 +129,6 @@ resolved "https://registry.yarnpkg.com/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz#add4c98d341472a289190b424efbdb096991bb24" integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw== -"@jridgewell/trace-mapping@0.3.9": - version "0.3.9" - resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz#6534fd5933a53ba7cbf3a17615e273a0d1273ff9" - integrity sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ== - dependencies: - "@jridgewell/resolve-uri" "^3.0.3" - "@jridgewell/sourcemap-codec" "^1.4.10" - "@jridgewell/trace-mapping@^0.3.14", "@jridgewell/trace-mapping@^0.3.9": version "0.3.16" resolved "https://registry.yarnpkg.com/@jridgewell/trace-mapping/-/trace-mapping-0.3.16.tgz#a7982f16c18cae02be36274365433e5b49d7b23f" @@ -173,26 +158,6 @@ "@nodelib/fs.scandir" "2.1.5" fastq "^1.6.0" -"@tsconfig/node10@^1.0.7": - version "1.0.9" - resolved "https://registry.yarnpkg.com/@tsconfig/node10/-/node10-1.0.9.tgz#df4907fc07a886922637b15e02d4cebc4c0021b2" - integrity sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA== - -"@tsconfig/node12@^1.0.7": - version "1.0.11" - resolved "https://registry.yarnpkg.com/@tsconfig/node12/-/node12-1.0.11.tgz#ee3def1f27d9ed66dac6e46a295cffb0152e058d" - integrity sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag== - -"@tsconfig/node14@^1.0.0": - version "1.0.3" - resolved "https://registry.yarnpkg.com/@tsconfig/node14/-/node14-1.0.3.tgz#e4386316284f00b98435bf40f72f75a09dabf6c1" - integrity sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow== - -"@tsconfig/node16@^1.0.2": - version "1.0.3" - resolved "https://registry.yarnpkg.com/@tsconfig/node16/-/node16-1.0.3.tgz#472eaab5f15c1ffdd7f8628bd4c4f753995ec79e" - integrity sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ== - "@types/eslint-scope@^3.7.3": version "3.7.4" resolved "https://registry.yarnpkg.com/@types/eslint-scope/-/eslint-scope-3.7.4.tgz#37fc1223f0786c39627068a12e94d6e6fc61de16" @@ -226,7 +191,7 @@ dependencies: "@types/sizzle" "*" -"@types/json-schema@*", "@types/json-schema@^7.0.8", "@types/json-schema@^7.0.9": +"@types/json-schema@*", "@types/json-schema@^7.0.8": version "7.0.11" resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.11.tgz#d421b6c527a3037f7c84433fd2c4229e016863d3" integrity sha512-wOuvG1SN4Us4rez+tylwwwCV1psiNVOkJeM3AUWUNWg/jDQY2+HE/444y5gc+jBmRqASOm2Oeh5c1axHobwRKQ== @@ -236,7 +201,7 @@ resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee" integrity sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ== -"@types/node@*", "@types/node@^18.0.4": +"@types/node@*": version "18.8.5" resolved "https://registry.yarnpkg.com/@types/node/-/node-18.8.5.tgz#6a31f820c1077c3f8ce44f9e203e68a176e8f59e" integrity sha512-Bq7G3AErwe5A/Zki5fdD3O6+0zDChhg671NfPjtIcbtzDNZTv4NPKMRFr7gtYPG7y+B8uTiNK4Ngd9T0FTar6Q== @@ -253,20 +218,6 @@ dependencies: "@types/jquery" "*" -"@typescript-eslint/eslint-plugin@^5.30.5": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.40.0.tgz#0159bb71410eec563968288a17bd4478cdb685bd" - integrity sha512-FIBZgS3DVJgqPwJzvZTuH4HNsZhHMa9SjxTKAZTlMsPw/UzpEjcf9f4dfgDJEHjK+HboUJo123Eshl6niwEm/Q== - dependencies: - "@typescript-eslint/scope-manager" "5.40.0" - "@typescript-eslint/type-utils" "5.40.0" - "@typescript-eslint/utils" "5.40.0" - debug "^4.3.4" - ignore "^5.2.0" - regexpp "^3.2.0" - semver "^7.3.7" - tsutils "^3.21.0" - "@typescript-eslint/parser@^5.30.5": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/parser/-/parser-5.40.0.tgz#432bddc1fe9154945660f67c1ba6d44de5014840" @@ -285,16 +236,6 @@ "@typescript-eslint/types" "5.40.0" "@typescript-eslint/visitor-keys" "5.40.0" -"@typescript-eslint/type-utils@5.40.0": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/type-utils/-/type-utils-5.40.0.tgz#4964099d0158355e72d67a370249d7fc03331126" - integrity sha512-nfuSdKEZY2TpnPz5covjJqav+g5qeBqwSHKBvz7Vm1SAfy93SwKk/JeSTymruDGItTwNijSsno5LhOHRS1pcfw== - dependencies: - "@typescript-eslint/typescript-estree" "5.40.0" - "@typescript-eslint/utils" "5.40.0" - debug "^4.3.4" - tsutils "^3.21.0" - "@typescript-eslint/types@5.40.0": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/types/-/types-5.40.0.tgz#8de07e118a10b8f63c99e174a3860f75608c822e" @@ -313,19 +254,6 @@ semver "^7.3.7" tsutils "^3.21.0" -"@typescript-eslint/utils@5.40.0": - version "5.40.0" - resolved "https://registry.yarnpkg.com/@typescript-eslint/utils/-/utils-5.40.0.tgz#647f56a875fd09d33c6abd70913c3dd50759b772" - integrity sha512-MO0y3T5BQ5+tkkuYZJBjePewsY+cQnfkYeRqS6tPh28niiIwPnQ1t59CSRcs1ZwJJNOdWw7rv9pF8aP58IMihA== - dependencies: - "@types/json-schema" "^7.0.9" - "@typescript-eslint/scope-manager" "5.40.0" - "@typescript-eslint/types" "5.40.0" - "@typescript-eslint/typescript-estree" "5.40.0" - eslint-scope "^5.1.1" - eslint-utils "^3.0.0" - semver "^7.3.7" - "@typescript-eslint/visitor-keys@5.40.0": version "5.40.0" resolved "https://registry.yarnpkg.com/@typescript-eslint/visitor-keys/-/visitor-keys-5.40.0.tgz#dd2d38097f68e0d2e1e06cb9f73c0173aca54b68" @@ -492,12 +420,7 @@ acorn-jsx@^5.3.2: resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937" integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ== -acorn-walk@^8.1.1: - version "8.2.0" - resolved "https://registry.yarnpkg.com/acorn-walk/-/acorn-walk-8.2.0.tgz#741210f2e2426454508853a2f44d0ab83b7f69c1" - integrity sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA== - -acorn@^8.4.1, acorn@^8.5.0, acorn@^8.7.1, acorn@^8.8.0: +acorn@^8.5.0, acorn@^8.7.1, acorn@^8.8.0: version "8.8.0" resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.0.tgz#88c0187620435c7f6015803f5539dae05a9dbea8" integrity sha512-QOxyigPVrpZ2GXT+PFyZTl6TtOFc5egxHIP9IlQ+RbupQuX4RkT/Bee4/kQuC02Xkzg84JcT7oLYtDIQxp+v7w== @@ -529,11 +452,6 @@ ansi-styles@^4.1.0: dependencies: color-convert "^2.0.1" -arg@^4.1.0: - version "4.1.3" - resolved "https://registry.yarnpkg.com/arg/-/arg-4.1.3.tgz#269fc7ad5b8e42cb63c896d5666017261c144089" - integrity sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA== - argparse@^2.0.1: version "2.0.1" resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38" @@ -729,11 +647,6 @@ core-js-pure@^3.25.1: resolved "https://registry.yarnpkg.com/core-js-pure/-/core-js-pure-3.25.5.tgz#79716ba54240c6aa9ceba6eee08cf79471ba184d" integrity sha512-oml3M22pHM+igfWHDfdLVq2ShWmjM2V4L+dQEBs0DWVIqEm9WHCwGAlZ6BmyBQGy5sFrJmcx+856D9lVKyGWYg== -create-require@^1.1.0: - version "1.1.1" - resolved "https://registry.yarnpkg.com/create-require/-/create-require-1.1.1.tgz#c1d7e8f1e5f6cfc9ff65f9cd352d37348756c333" - integrity sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ== - cross-spawn@^7.0.2, cross-spawn@^7.0.3: version "7.0.3" resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6" @@ -800,11 +713,6 @@ define-properties@^1.1.3, define-properties@^1.1.4: has-property-descriptors "^1.0.0" object-keys "^1.1.1" -diff@^4.0.1: - version "4.0.2" - resolved "https://registry.yarnpkg.com/diff/-/diff-4.0.2.tgz#60f3aecb89d5fae520c11aa19efc2bb982aade7d" - integrity sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A== - dir-glob@^3.0.1: version "3.0.1" resolved "https://registry.yarnpkg.com/dir-glob/-/dir-glob-3.0.1.tgz#56dbf73d992a4a93ba1584f4534063fd2e41717f" @@ -1082,7 +990,7 @@ eslint-plugin-react@^7.28.0: semver "^6.3.0" string.prototype.matchall "^4.0.7" -eslint-scope@5.1.1, eslint-scope@^5.1.1: +eslint-scope@5.1.1: version "5.1.1" resolved "https://registry.yarnpkg.com/eslint-scope/-/eslint-scope-5.1.1.tgz#e786e59a66cb92b3f6c1fb0d508aab174848f48c" integrity sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw== @@ -1775,11 +1683,6 @@ lru-cache@^6.0.0: dependencies: yallist "^4.0.0" -make-error@^1.1.1: - version "1.3.6" - resolved "https://registry.yarnpkg.com/make-error/-/make-error-1.3.6.tgz#2eb2e37ea9b67c4891f684a1394799af484cf7a2" - integrity sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw== - mathjax-full@^3.2.0: version "3.2.2" resolved "https://registry.yarnpkg.com/mathjax-full/-/mathjax-full-3.2.2.tgz#43f02e55219db393030985d2b6537ceae82f1fa7" @@ -2053,11 +1956,6 @@ prelude-ls@^1.2.1: resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.2.1.tgz#debc6489d7a6e6b0e7611888cec880337d316396" integrity sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g== -prettier@^2.7.1: - version "2.7.1" - resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.7.1.tgz#e235806850d057f97bb08368a4f7d899f7760c64" - integrity sha512-ujppO+MkdPqoVINuDFDRLClm7D78qbDt0/NR+wp5FqEZOoTNAjPHWj17QRhu7geIHJfcNhRk1XVQmF8Bp3ye+g== - proj4@^2.7.5: version "2.8.0" resolved "https://registry.yarnpkg.com/proj4/-/proj4-2.8.0.tgz#b2cb8f3ccd56d4dcc7c3e46155cd02caa804b170" @@ -2412,25 +2310,6 @@ ts-loader@^9.3.1: micromatch "^4.0.0" semver "^7.3.4" -ts-node@^10.9.1: - version "10.9.1" - resolved "https://registry.yarnpkg.com/ts-node/-/ts-node-10.9.1.tgz#e73de9102958af9e1f0b168a6ff320e25adcff4b" - integrity sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw== - dependencies: - "@cspotcode/source-map-support" "^0.8.0" - "@tsconfig/node10" "^1.0.7" - "@tsconfig/node12" "^1.0.7" - "@tsconfig/node14" "^1.0.0" - "@tsconfig/node16" "^1.0.2" - acorn "^8.4.1" - acorn-walk "^8.1.1" - arg "^4.1.0" - create-require "^1.1.0" - diff "^4.0.1" - make-error "^1.1.1" - v8-compile-cache-lib "^3.0.1" - yn "3.1.1" - tsconfig-paths@^3.14.1: version "3.14.1" resolved "https://registry.yarnpkg.com/tsconfig-paths/-/tsconfig-paths-3.14.1.tgz#ba0734599e8ea36c862798e920bcf163277b137a" @@ -2515,11 +2394,6 @@ uri-js@^4.2.2: dependencies: punycode "^2.1.0" -v8-compile-cache-lib@^3.0.1: - version "3.0.1" - resolved "https://registry.yarnpkg.com/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz#6336e8d71965cb3d35a1bbb7868445a7c05264bf" - integrity sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg== - watchpack@^2.4.0: version "2.4.0" resolved "https://registry.yarnpkg.com/watchpack/-/watchpack-2.4.0.tgz#fa33032374962c78113f93c7f2fb4c54c9862a5d" @@ -2642,11 +2516,6 @@ yallist@^4.0.0: resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A== -yn@3.1.1: - version "3.1.1" - resolved "https://registry.yarnpkg.com/yn/-/yn-3.1.1.tgz#1e87401a09d767c1d5eab26a6e4c185182d2eb50" - integrity sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q== - yocto-queue@^0.1.0: version "0.1.0" resolved "https://registry.yarnpkg.com/yocto-queue/-/yocto-queue-0.1.0.tgz#0294eb3dee05028d31ee1a5fa2c556a6aaf10a1b" diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py index e6ce876e1c..1e122cba87 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py @@ -4,8 +4,7 @@ # LICENSE file in the root directory of this source tree. """Marginal 1D diagnostic tool for a Bean Machine model.""" - -from typing import TypeVar +from __future__ import annotations from beanmachine.ppl.diagnostics.tools.marginal1d import utils from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( @@ -16,9 +15,6 @@ from bokeh.models.callbacks import CustomJS -T = TypeVar("T", bound="Marginal1d") - - class Marginal1d(DiagnosticToolBaseClass): """ Marginal 1D diagnostic tool. @@ -40,10 +36,10 @@ class Marginal1d(DiagnosticToolBaseClass): independently from a Python server. """ - def __init__(self: T, mcs: MonteCarloSamples) -> None: + def __init__(self: Marginal1d, mcs: MonteCarloSamples) -> None: super(Marginal1d, self).__init__(mcs) - def create_document(self: T) -> Model: + def create_document(self: Marginal1d) -> Model: # Initialize widget values using Python. rv_name = self.rv_names[0] bw_factor = 1.0 diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py index 881cd308e4..8796cc70e3 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Marginal 1D diagnostic tool types for a Bean Machine model.""" - from typing import Any, Dict, List, Union from beanmachine.ppl.diagnostics.tools import TypedDict diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py index 6ef4dfebf7..ff148175c3 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Methods used to generate the diagnostic tool.""" - from typing import List import numpy as np diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/__init__.py b/src/beanmachine/ppl/diagnostics/tools/trace/__init__.py new file mode 100644 index 0000000000..7bec24cb17 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/trace/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/tool.py b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py new file mode 100644 index 0000000000..710acc38b5 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Trace diagnostic tool for a Bean Machine model.""" +from __future__ import annotations + +from beanmachine.ppl.diagnostics.tools.trace import utils +from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( + DiagnosticToolBaseClass, +) +from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples +from bokeh.models import Model +from bokeh.models.callbacks import CustomJS + + +class Trace(DiagnosticToolBaseClass): + """Trace tool. + + Args: + mcs (MonteCarloSamples): The return object from running a Bean Machine model. + + Attributes: + data (Dict[str, List[List[float]]]): JSON serializable representation of the + given `mcs` object. + rv_names (List[str]): The list of random variables string names for the given + model. + num_chains (int): The number of chains of the model. + num_draws (int): The number of draws of the model for each chain. + palette (List[str]): A list of color values used for the glyphs in the figures. + The colors are specifically chosen from the Colorblind palette defined in + Bokeh. + tool_js (str):The JavaScript callbacks needed to render the Bokeh tool + independently from a Python server. + """ + + def __init__(self: Trace, mcs: MonteCarloSamples) -> None: + super(Trace, self).__init__(mcs) + + def create_document(self: Trace) -> Model: + # Initialize widget values using Python. + rv_name = self.rv_names[0] + + # NOTE: We are going to use Python and Bokeh to render the tool in the notebook + # output cell, however, we WILL NOT use Python to calculate any of the + # statistics displayed in the tool. We do this so we can make the BROWSER + # run all the calculations based on user interactions. If we did not + # employ this strategy, then the initial display a user would receive + # would be calculated by Python, and any subsequent updates would be + # calculated by JavaScript. The side-effect of having two backends + # calculate data could cause the figures to flicker, which would not be a + # good end user experience. + # + # Bokeh 3.0 is implementing an "on load" feature, which would nullify this + # requirement, and until that version is released, we have to employ this + # work-around. + + # Create empty Bokeh sources using Python. + sources = utils.create_sources(num_chains=self.num_chains) + + # Create empty figures for the tool using Python. + figures = utils.create_figures(rv_name=rv_name, num_chains=self.num_chains) + + # Create empty glyphs and attach them to the figures using Python. + glyphs = utils.create_glyphs(num_chains=self.num_chains) + utils.add_glyphs(sources=sources, figures=figures, glyphs=glyphs) + + # Create empty annotations and attach them to the figures using Python. + annotations = utils.create_annotations( + figures=figures, + num_chains=self.num_chains, + ) + utils.add_annotations(figures=figures, annotations=annotations) + + # Create empty tool tips and attach them to the figures using Python. + tooltips = utils.create_tooltips( + figures=figures, + rv_name=rv_name, + num_chains=self.num_chains, + ) + utils.add_tooltips(figures=figures, tooltips=tooltips) + + # Create the widgets for the tool using Python. + widgets = utils.create_widgets(rv_names=self.rv_names, rv_name=rv_name) + + # Create the view of the tool and serialize it into HTML using static resources + # from Bokeh. Embedding the tool in this manner prevents external CDN calls for + # JavaScript resources, and prevents the user from having to know where the + # Bokeh server is. + tool_view = utils.create_view(figures=figures, widgets=widgets) + + # Create callbacks for the tool using JavaScript. + callback_js = f""" + const rvName = widgets.rv_select.value; + const rvData = data[rvName]; + let bw = 0.0; + // Remove the CSS classes that dim the tool output on initial load. + const toolTab = toolView.tabs[0]; + const toolChildren = toolTab.child.children; + const dimmedComponent = toolChildren[1]; + dimmedComponent.css_classes = []; + try {{ + trace.update( + rvData, + rvName, + bwFactor, + hdiProbability, + sources, + figures, + tooltips, + ); + }} catch (error) {{ + {self.tool_js} + trace.update( + rvData, + rvName, + bwFactor, + hdiProbability, + sources, + figures, + tooltips, + ); + }} + """ + + # Each widget requires the following dictionary for the CustomJS method. Notice + # that the callback_js object above uses the names defined as keys in the below + # object with values defined by the Python objects. + callback_arguments = { + "data": self.data, + "widgets": widgets, + "sources": sources, + "figures": figures, + "tooltips": tooltips, + "toolView": tool_view, + } + + # Each widget requires slightly different JS. + rv_select_js = f""" + const bwFactor = 1.0; + const hdiProbability = 0.89; + widgets.bw_factor_slider.value = bwFactor; + widgets.hdi_slider.value = 100 * hdiProbability; + {callback_js}; + figures.marginals.reset.emit(); + """ + slider_js = f""" + const bwFactor = widgets.bw_factor_slider.value; + const hdiProbability = widgets.hdi_slider.value / 100; + {callback_js}; + """ + slider_callback = CustomJS(args=callback_arguments, code=slider_js) + rv_select_callback = CustomJS(args=callback_arguments, code=rv_select_js) + + # Tell Python to use the JavaScript. + widgets["rv_select"].js_on_change("value", rv_select_callback) + widgets["bw_factor_slider"].js_on_change("value", slider_callback) + widgets["hdi_slider"].js_on_change("value", slider_callback) + + return tool_view diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/typing.py b/src/beanmachine/ppl/diagnostics/tools/trace/typing.py new file mode 100644 index 0000000000..1841bc7c70 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/trace/typing.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Trace diagnostic tool types for a Bean Machine model.""" +from typing import Any, Dict, List, Union + +from beanmachine.ppl.diagnostics.tools import NotRequired, TypedDict +from bokeh.models.annotations import Legend +from bokeh.models.glyphs import Circle, Line, Quad +from bokeh.models.sources import ColumnDataSource +from bokeh.models.tools import HoverTool +from bokeh.models.widgets.inputs import Select +from bokeh.models.widgets.sliders import Slider +from bokeh.plotting.figure import Figure + + +# NOTE: These are the types pyre gives us when using `reveal_type(...)` on the outputs +# of the methods. +Data = Dict[str, Dict[Any, Any]] +Sources = Dict[Any, Any] +Figures = Dict[Any, Any] +Glyphs = Dict[Any, Any] +Annotations = Dict[str, Dict[str, Legend]] +Tooltips = Dict[Any, Any] +Widgets = Dict[str, Union[Select, Slider]] + + +# NOTE: TypedDict objects are for reference only. Due to the way pyre accesses keys in +# dictionaries, and how NumPy casts arrays when using tolist(), we are unable to +# use them, but they provide semantic information for the different types. We must +# ignore a lot of lines due to the issue discussed here +# https://pyre-check.org/docs/errors/#13-uninitialized-attribute. + + +class _LineOrCircleGlyphData(TypedDict): # pyre-ignore + x: List[float] + y: List[float] + + +class _QuadGlyphData(TypedDict): # pyre-ignore + """Follow the RankHistogram interface in stats/histogram.js.""" + + left: List[float] + top: List[float] + right: List[float] + bottom: List[float] + chain: List[int] + draws: List[str] + rank: List[float] + + +class _MarginalDataSingleChain(TypedDict): # pyre-ignore + line: _LineOrCircleGlyphData + chain: int + mean: float + bandwidth: float + + +class _ForestDataSingleChain(TypedDict): # pyre-ignore + line: _LineOrCircleGlyphData + circle: _LineOrCircleGlyphData + chain: int + mean: float + + +class _TraceDataSingleChain(TypedDict): # pyre-ignore + line: _LineOrCircleGlyphData + chain: int + mean: float + + +class _RankDataSingleChain(TypedDict): # pyre-ignore + quad: _QuadGlyphData + line: _LineOrCircleGlyphData + chain: List[int] + rankMean: List[float] + mean: List[float] + + +_MarginalDataAllChains = Dict[str, _MarginalDataSingleChain] +_ForestDataAllChains = Dict[str, _ForestDataSingleChain] +_TraceDataAllChains = Dict[str, _TraceDataSingleChain] +_RankDataAllChains = Dict[str, _RankDataSingleChain] + + +class _Data(TypedDict): # pyre-ignore + marginals: _MarginalDataAllChains + forests: _ForestDataAllChains + traces: _TraceDataAllChains + ranks: _RankDataAllChains + + +class _SourceSingleChain(TypedDict): # pyre-ignore + line: ColumnDataSource + circle: NotRequired[ColumnDataSource] + quad: NotRequired[ColumnDataSource] + + +_SourceAllChains = Dict[str, _SourceSingleChain] + + +class _Sources(TypedDict): # pyre-ignore + marginals: _SourceAllChains + forests: _SourceAllChains + traces: _SourceAllChains + ranks: _SourceAllChains + + +class _Figures(TypedDict): # pyre-ignore + marginals: Figure + forests: Figure + traces: Figure + ranks: Figure + + +class _RankTooltip(TypedDict): # pyre-ignore + line: HoverTool + quad: HoverTool + + +class _Tooltips(TypedDict): # pyre-ignore + marginals: List[HoverTool] + forests: List[HoverTool] + traces: List[HoverTool] + ranks: List[_RankTooltip] + + +class _Glyph(TypedDict): # pyre-ignore + glyph: Union[Circle, Line, Quad] + hover_glyph: Union[Circle, Line, Quad] + + +class _GlyphSingleChain(TypedDict): # pyre-ignore + line: _Glyph + circle: NotRequired[_Glyph] + quad: NotRequired[_Glyph] + + +_GlyphAllChains = Dict[str, _GlyphSingleChain] + + +class _Glyphs(TypedDict): # pyre-ignore + marginals: _GlyphAllChains + forests: _GlyphAllChains + traces: _GlyphAllChains + ranks: _GlyphAllChains + + +_Annotations = Dict[str, Dict[str, Legend]] + + +class _Widgets(TypedDict): # pyre-ignore + rv_select: Select + bw_factor_slider: Slider + hdi_slider: Slider diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/utils.py b/src/beanmachine/ppl/diagnostics/tools/trace/utils.py new file mode 100644 index 0000000000..df99a5cb0a --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/trace/utils.py @@ -0,0 +1,731 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Methods used to generate the diagnostic tool.""" +from typing import List + +from beanmachine.ppl.diagnostics.tools.trace import typing +from beanmachine.ppl.diagnostics.tools.utils import plotting_utils +from bokeh.core.property.wrappers import PropertyValueList +from bokeh.models.annotations import Legend, LegendItem +from bokeh.models.glyphs import Circle, Line, Quad +from bokeh.models.layouts import Column, Row +from bokeh.models.sources import ColumnDataSource +from bokeh.models.tools import HoverTool +from bokeh.models.widgets.inputs import Select +from bokeh.models.widgets.markups import Div +from bokeh.models.widgets.panels import Panel, Tabs +from bokeh.models.widgets.sliders import Slider +from bokeh.plotting.figure import figure + +PLOT_WIDTH = 400 +PLOT_HEIGHT = 500 +TRACE_PLOT_WIDTH = 600 +FIGURE_NAMES = ["marginals", "forests", "traces", "ranks"] +# Define what the empty data object looks like in order to make the browser handle all +# computations. +EMPTY_DATA = {} + + +def create_empty_data(num_chains: int) -> typing.Data: + """Create an empty data object for the tool. + + We do not know a priori how many chains a model will have, so we use this method to + build an empty data object with the given number of chains. + + Parameters + ---------- + num_chains : int + The number of chains from the model. + + Returns + ------- + typing.Data + An empty data object to be filled by JavaScript. + """ + output = { + "marginals": {}, + "forests": {}, + "traces": {}, + "ranks": {}, + } + for chain in range(num_chains): + chain_index = chain + 1 + chain_name = f"chain{chain_index}" + marginal = { + "line": {"x": [], "y": []}, + "chain": [], + "mean": [], + "bandwidth": [], + } + forest = { + "line": {"x": [], "y": []}, + "circle": {"x": [], "y": []}, + "chain": [], + "mean": [], + } + trace = { + "line": {"x": [], "y": []}, + "chain": [], + "mean": [], + } + rank = { + "quad": { + "left": [], + "top": [], + "right": [], + "bottom": [], + "chain": [], + "draws": [], + "rank": [], + }, + "line": {"x": [], "y": []}, + "chain": [], + "rankMean": [], + "mean": [], + } + single_chain_data = [marginal, forest, trace, rank] + chain_data = dict(zip(FIGURE_NAMES, single_chain_data)) + for figure_name in FIGURE_NAMES: + output[figure_name][chain_name] = chain_data[figure_name] + return output + + +def create_sources(num_chains: int) -> typing.Sources: + """Create Bokeh sources from the given data that will be bound to glyphs. + + Parameters + ---------- + num_chains : int + The number of chains from the model. + + Returns + ------- + typing.Sources + A dictionary of Bokeh ColumnDataSource objects. + """ + global EMPTY_DATA + if not EMPTY_DATA: + EMPTY_DATA = create_empty_data(num_chains=num_chains) + + output = {} + for figure_name, figure_data in EMPTY_DATA.items(): + output[figure_name] = {} + for chain_name, chain_data in figure_data.items(): + output[figure_name][chain_name] = {} + if figure_name == "marginals": + output[figure_name][chain_name]["line"] = ColumnDataSource( + { + "x": chain_data["line"]["x"], + "y": chain_data["line"]["y"], + "chain": chain_data["chain"], + "mean": chain_data["mean"], + }, + ) + if figure_name == "forests": + output[figure_name][chain_name]["line"] = ColumnDataSource( + { + "x": chain_data["line"]["x"], + "y": chain_data["line"]["y"], + }, + ) + output[figure_name][chain_name]["circle"] = ColumnDataSource( + { + "x": chain_data["circle"]["x"], + "y": chain_data["circle"]["y"], + "chain": chain_data["chain"], + }, + ) + if figure_name == "traces": + output[figure_name][chain_name]["line"] = ColumnDataSource( + { + "x": chain_data["line"]["x"], + "y": chain_data["line"]["y"], + "chain": chain_data["chain"], + "mean": chain_data["mean"], + }, + ) + if figure_name == "ranks": + output[figure_name][chain_name]["line"] = ColumnDataSource( + { + "x": chain_data["line"]["x"], + "y": chain_data["line"]["y"], + "chain": chain_data["chain"], + "rankMean": chain_data["rankMean"], + }, + ) + output[figure_name][chain_name]["quad"] = ColumnDataSource( + { + "left": chain_data["quad"]["left"], + "top": chain_data["quad"]["top"], + "right": chain_data["quad"]["right"], + "bottom": chain_data["quad"]["bottom"], + "chain": chain_data["chain"], + "draws": chain_data["quad"]["draws"], + "rank": chain_data["quad"]["rank"], + }, + ) + return output + + +def create_figures(rv_name: str, num_chains: int) -> typing.Figures: + """Create the Bokeh figures used for the tool. + + Parameters + ---------- + rv_name : str + The string representation of the random variable data. + num_chains : int + The number of chains from the model. + + Returns + ------- + typing.Figures + A dictionary of Bokeh Figure objects. + """ + output = {} + for figure_name in FIGURE_NAMES: + fig = figure( + width=PLOT_WIDTH, + height=PLOT_HEIGHT, + outline_line_color="black", + sizing_mode="scale_both", + ) + plotting_utils.style_figure(fig) + # NOTE: There are several figures where we do not want the x-axis to change its + # limits. This is why we set the x_range to an object from Bokeh called + # Range1d. + if figure_name == "marginals": + fig.title = "Marginal" + fig.xaxis.axis_label = rv_name + fig.yaxis.visible = False + elif figure_name == "forests": + fig.title = "Forest" + fig.xaxis.axis_label = rv_name + fig.yaxis.axis_label = "Chain" + fig.yaxis.minor_tick_line_color = None + fig.yaxis.ticker.desired_num_ticks = num_chains + elif figure_name == "traces": + fig.title = "Trace" + fig.xaxis.axis_label = "Draw from single chain" + fig.yaxis.axis_label = rv_name + fig.width = TRACE_PLOT_WIDTH + elif figure_name == "ranks": + fig.title = "Rank" + fig.xaxis.axis_label = "Rank from all chains" + fig.yaxis.axis_label = "Chain" + fig.width = TRACE_PLOT_WIDTH + fig.yaxis.minor_tick_line_color = None + fig.yaxis.ticker.desired_num_ticks = num_chains + output[figure_name] = fig + return output + + +def create_glyphs(num_chains: int) -> typing.Glyphs: + """Create the glyphs used for the figures of the tool. + + Parameters + ---------- + num_chains : int + The number of chains from the model. + + Returns + ------- + typing.Glyphs + A dictionary of Bokeh Glyphs objects. + """ + global EMPTY_DATA + if not EMPTY_DATA: + EMPTY_DATA = create_empty_data(num_chains=num_chains) + + palette = plotting_utils.choose_palette(num_colors=num_chains) + output = {} + for figure_name, figure_data in EMPTY_DATA.items(): + output[figure_name] = {} + for i, (chain_name, _) in enumerate(figure_data.items()): + output[figure_name][chain_name] = {} + color = palette[i] + if figure_name == "marginals": + output[figure_name][chain_name]["line"] = { + "glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=0.7, + line_width=2.0, + name=f"{figure_name}{chain_name.title()}LineGlyph", + ), + "hover_glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=1.0, + line_width=2.0, + name=f"{figure_name}{chain_name.title()}LineHoverGlyph", + ), + } + elif figure_name == "forests": + output[figure_name][chain_name] = { + "line": { + "glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=0.7, + line_width=2.0, + name=f"{figure_name}{chain_name.title()}LineGlyph", + ), + "hover_glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=1.0, + line_width=2.0, + name=f"{figure_name}{chain_name.title()}LineHoverGlyph", + ), + }, + "circle": { + "glyph": Circle( + x="x", + y="y", + size=10, + fill_color=color, + fill_alpha=0.7, + line_color="white", + name=f"{figure_name}{chain_name.title()}CircleGlyph", + ), + "hover_glyph": Circle( + x="x", + y="y", + size=10, + fill_color=color, + fill_alpha=1.0, + line_color="black", + name=f"{figure_name}{chain_name.title()}CircleHoverGlyph", + ), + }, + } + if figure_name == "traces": + output[figure_name][chain_name]["line"] = { + "glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=0.6, + line_width=0.6, + name=f"{figure_name}{chain_name.title()}LineGlyph", + ), + "hover_glyph": Line( + x="x", + y="y", + line_color=color, + line_alpha=0.6, + line_width=1.0, + name=f"{figure_name}{chain_name.title()}LineHoverGlyph", + ), + } + if figure_name == "ranks": + output[figure_name][chain_name] = { + "quad": { + "glyph": Quad( + left="left", + top="top", + right="right", + bottom="bottom", + fill_color=color, + fill_alpha=0.7, + line_color="white", + name=f"{figure_name}{chain_name.title()}QuadGlyph", + ), + "hover_glyph": Quad( + left="left", + top="top", + right="right", + bottom="bottom", + fill_color=color, + fill_alpha=1.0, + line_color="black", + name=f"{figure_name}{chain_name.title()}QuadHoverGlyph", + ), + }, + "line": { + "glyph": Line( + x="x", + y="y", + line_color="grey", + line_alpha=0.7, + line_width=3.0, + line_dash="dashed", + name=f"{figure_name}{chain_name.title()}LineGlyph", + ), + "hover_glyph": Line( + x="x", + y="y", + line_color="grey", + line_alpha=1.0, + line_width=3.0, + line_dash="solid", + name=f"{figure_name}{chain_name.title()}LineGlyph", + ), + }, + } + return output + + +def add_glyphs( + figures: typing.Figures, + glyphs: typing.Glyphs, + sources: typing.Sources, +) -> None: + """Bind source data to glyphs and add the glyphs to the given figures. + + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + glyphs : typing.Glyphs + A dictionary of Bokeh Glyphs objects. + sources : typing.Sources + A dictionary of Bokeh ColumnDataSource objects. + + Returns + ------- + None + Adds data bound glyphs to the given figures directly. + """ + for figure_name, figure_sources in sources.items(): + fig = figures[figure_name] + for chain_name, source in figure_sources.items(): + chain_glyphs = glyphs[figure_name][chain_name] + # NOTE: Every figure has a line glyph, so we always add it here. + fig.add_glyph( + source_or_glyph=source["line"], + glyph=chain_glyphs["line"]["glyph"], + hover_glyph=chain_glyphs["line"]["hover_glyph"], + name=chain_glyphs["line"]["glyph"].name, + ) + # We want to keep the x-axis from moving when changing queries, so we add + # the bounds below from the marginal figure. All figures that need to keep + # its range stable are linked to the marginal figure's range below. + if figure_name == "marginals": + pass + elif figure_name == "forests": + fig.add_glyph( + source_or_glyph=source["circle"], + glyph=chain_glyphs["circle"]["glyph"], + hover_glyph=chain_glyphs["circle"]["hover_glyph"], + name=chain_glyphs["circle"]["glyph"].name, + ) + elif figure_name == "ranks": + fig.add_glyph( + source_or_glyph=source["quad"], + glyph=chain_glyphs["quad"]["glyph"], + hover_glyph=chain_glyphs["quad"]["hover_glyph"], + name=chain_glyphs["quad"]["glyph"].name, + ) + # Link figure ranges together. + figures["forests"].x_range = figures["marginals"].x_range + + +def create_annotations(figures: typing.Figures, num_chains: int) -> typing.Annotations: + """Create any annotations for the figures of the tool. + + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + num_chains : int + The number of chains of the model. + + Returns + ------- + typing.Annotations + A dictionary of Bokeh Annotation objects. + """ + renderers = [] + for _, fig in figures.items(): + renderers.extend(PropertyValueList(fig.renderers)) + legend_items = [] + for chain in range(num_chains): + chain_index = chain + 1 + chain_name = f"chain{chain_index}" + legend_items.append( + LegendItem( + renderers=[ + renderer + for renderer in renderers + if chain_name in renderer.name.lower() + ], + label=chain_name, + ), + ) + legend = Legend( + items=legend_items, + orientation="horizontal", + border_line_color="black", + click_policy="hide", + ) + output = {"traces": {"legend": legend}, "ranks": {"legend": legend}} + return output + + +def add_annotations(figures: typing.Figures, annotations: typing.Annotations) -> None: + """Add the given annotations to the given figures of the tool. + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + annotations : typing.Annotations + A dictionary of Bokeh Annotation objects. + Returns + ------- + None + Adds annotations directly to the given figures. + """ + for figure_name, figure_annotations in annotations.items(): + fig = figures[figure_name] + for _, annotation in figure_annotations.items(): + fig.add_layout(annotation, "below") + + +def create_tooltips( + rv_name: str, + figures: typing.Figures, + num_chains: int, +) -> typing.Tooltips: + """Create hover tools for the glyphs used in the figures of the tool. + + Parameters + ---------- + rv_name : str + The string representation of the random variable data. + figures : typing.Figures + A dictionary of Bokeh Figure objects. + num_chains : int + The number of chains of the model. + + Returns + ------- + typing.Tooltips + A dictionary of Bokeh HoverTools objects. + """ + output = {} + for figure_name, fig in figures.items(): + output[figure_name] = [] + for chain in range(num_chains): + chain_index = chain + 1 + chain_name = f"chain{chain_index}" + if figure_name == "marginals": + glyph_name = f"{figure_name}{chain_name.title()}LineGlyph" + output[figure_name].append( + HoverTool( + renderers=plotting_utils.filter_renderers(fig, glyph_name), + tooltips=[ + ("Chain", "@chain"), + ("Mean", "@mean"), + (rv_name, "@x"), + ], + ), + ) + if figure_name == "forests": + glyph_name = f"{figure_name}{chain_name.title()}CircleGlyph" + output[figure_name].append( + HoverTool( + renderers=plotting_utils.filter_renderers(fig, glyph_name), + tooltips=[ + ("Chain", "@chain"), + (rv_name, "@x"), + ], + ), + ) + if figure_name == "traces": + glyph_name = f"{figure_name}{chain_name.title()}LineGlyph" + output[figure_name].append( + HoverTool( + renderers=plotting_utils.filter_renderers(fig, glyph_name), + tooltips=[ + ("Chain", "@chain"), + ("Mean", "@mean"), + (rv_name, "@y"), + ], + ), + ) + if figure_name == "ranks": + output[figure_name].append( + { + "line": HoverTool( + renderers=plotting_utils.filter_renderers( + fig, + f"{figure_name}{chain_name.title()}LineGlyph", + ), + tooltips=[("Chain", "@chain"), ("Rank mean", "@rankMean")], + ), + "quad": HoverTool( + renderers=plotting_utils.filter_renderers( + fig, + f"{figure_name}{chain_name.title()}QuadGlyph", + ), + tooltips=[ + ("Chain", "@chain"), + ("Draws", "@draws"), + ("Rank", "@rank"), + ], + ), + }, + ) + return output + + +def add_tooltips(figures: typing.Figures, tooltips: typing.Tooltips) -> None: + """Add the given tools to the figures. + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + tooltips : typing.Tooltips + A dictionary of Bokeh HoverTools objects. + Returns + ------- + None + Adds the tooltips directly to the given figures. + """ + for figure_name, fig in figures.items(): + for tips in tooltips[figure_name]: + if figure_name == "ranks": + for _, tips_ in tips.items(): + fig.add_tools(tips_) + else: + fig.add_tools(tips) + + +def create_widgets(rv_names: List[str], rv_name: str) -> typing.Widgets: + """Create the widgets used in the tool. + + Parameters + ---------- + rv_names : List[str] + A list of all available random variable names. + rv_name : str + The string representation of the random variable data. + + Returns + ------- + typing.Widgets + A dictionary of Bokeh widget objects. + """ + output = { + "rv_select": Select(value=rv_name, options=rv_names, title="Query"), + "bw_factor_slider": Slider( + start=0.01, + end=2.00, + step=0.01, + value=1.0, + title="Bandwidth factor", + ), + "hdi_slider": Slider(start=1, end=99, step=1, value=89, title="HDI"), + } + return output + + +def help_page() -> Div: + """Help tab for the tool. + Returns + ------- + Div + Bokeh Div widget containing the help tab information. + """ + text = """ +

Rank plots

+

+ Rank plots are a histogram of the samples over time. All samples across + all chains are ranked and then we plot the average rank for each chain on + regular intervals. If the chains are mixing well this histogram should + look roughly uniform. If it looks highly irregular that suggests chains + might be getting stuck and not adequately exploring the sample space. + See the paper by Vehtari et al for more information. +

+

Trace plots

+

+ The more familiar trace plots are also included in this widget. You can + click on the legend to show/hide different chains and compare them to the + rank plots. +

+ + """ + return Div(text=text, disable_math=False, min_width=PLOT_WIDTH) + + +def create_view(figures: typing.Figures, widgets: typing.Widgets) -> Tabs: + """Create the tool view. + + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + widgets : typing.Widgets + A dictionary of Bokeh widget objects. + + Returns + ------- + Tabs + Bokeh Tabs objects. + """ + toolbar = plotting_utils.create_toolbar(list(figures.values())) + help_panel = Panel(child=help_page(), title="Help", name="helpPanel") + marginal_panel = Panel( + child=Column( + children=[figures["marginals"], widgets["bw_factor_slider"]], + sizing_mode="scale_both", + ), + title="Marginals", + ) + forest_panel = Panel( + child=Column( + children=[figures["forests"], widgets["hdi_slider"]], + sizing_mode="scale_both", + ), + title="HDIs", + ) + left_panels = Tabs(tabs=[marginal_panel, forest_panel], sizing_mode="scale_both") + trace_panel = Panel( + child=Column(children=[figures["traces"]], sizing_mode="scale_both"), + title="Traces", + ) + rank_panel = Panel( + child=Column(children=[figures["ranks"]], sizing_mode="scale_both"), + title="Ranks", + ) + right_panels = Tabs(tabs=[trace_panel, rank_panel], sizing_mode="scale_both") + tool_panel = Panel( + child=Column( + children=[ + widgets["rv_select"], + Row( + children=[left_panels, right_panels, toolbar], + sizing_mode="scale_both", + ), + ], + sizing_mode="scale_both", + ), + title="Trace tool", + ) + return Tabs(tabs=[tool_panel, help_panel], sizing_mode="scale_both") diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py b/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py index e4e23f6bc6..9146b3c8aa 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/accessor.py @@ -10,17 +10,15 @@ - `pandas`: https://github.com/pandas-dev/pandas/blob/main/pandas/core/accessor.py - `xarray`: https://github.com/pydata/xarray/blob/main/xarray/core/extensions.py """ +from __future__ import annotations import contextlib import warnings -from typing import Callable, TypeVar +from typing import Callable from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples -T = TypeVar("T", bound="CachedAccessor") - - class CachedAccessor: """ A descriptor for caching accessors. @@ -38,11 +36,11 @@ class CachedAccessor: object. """ - def __init__(self: T, name: str, accessor: object) -> None: + def __init__(self: CachedAccessor, name: str, accessor: object) -> None: self._name = name self._accessor = accessor - def __get__(self: T, obj: object, cls: object) -> object: + def __get__(self: CachedAccessor, obj: object, cls: object) -> object: """ Method to retrieve the accessor namespace. @@ -123,7 +121,8 @@ def register_mcs_accessor(name: str) -> Callable: object. Example: - >>> from typing import Dict, List, TypeVar + >>> from __future__ import annotations + >>> from typing import Dict, List >>> >>> import beanmachine.ppl as bm >>> import numpy as np @@ -132,8 +131,6 @@ def register_mcs_accessor(name: str) -> Callable: >>> from beanmachine.ppl.diagnostics.tools.utils import accessor >>> from torch import tensor >>> - >>> T = TypeVar("T", bound="MagicAccessor") - >>> >>> @bm.random_variable >>> def alpha(): >>> return dist.Normal(0, 1) @@ -144,9 +141,9 @@ def register_mcs_accessor(name: str) -> Callable: >>> >>> @accessor.register_mcs_accessor("magic") >>> class MagicAccessor: - >>> def __init__(self: T, mcs: MonteCarloSamples) -> None: + >>> def __init__(self: MagicAccessor, mcs: MonteCarloSamples) -> None: >>> self.mcs = mcs - >>> def show_me(self: T) -> Dict[str, List[List[float]]]: + >>> def show_me(self: MagicAccessor) -> Dict[str, List[List[float]]]: >>> # Return a JSON serializable object from a MonteCarloSamples object. >>> return dict( >>> sorted( diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py b/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py index 9d79084b10..d1c3c0488d 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. """Base class for diagnostic tools of a Bean Machine model.""" +from __future__ import annotations import re from abc import ABC, abstractmethod -from typing import Any, Mapping, TypeVar +from typing import Any, Mapping from beanmachine.ppl.diagnostics.tools import JS_DIST_DIR from beanmachine.ppl.diagnostics.tools.utils import plotting_utils @@ -18,9 +19,6 @@ from bokeh.resources import INLINE -T = TypeVar("T", bound="DiagnosticToolBaseClass") - - class DiagnosticToolBaseClass(ABC): """ Base class for visual diagnostic tools. @@ -43,7 +41,7 @@ class DiagnosticToolBaseClass(ABC): """ @abstractmethod - def __init__(self: T, mcs: MonteCarloSamples) -> None: + def __init__(self: DiagnosticToolBaseClass, mcs: MonteCarloSamples) -> None: self.data = serialize_bm(mcs) self.rv_names = ["Select a random variable..."] + list(self.data.keys()) self.num_chains = mcs.num_chains @@ -51,7 +49,7 @@ def __init__(self: T, mcs: MonteCarloSamples) -> None: self.palette = plotting_utils.choose_palette(self.num_chains) self.tool_js = self.load_tool_js() - def load_tool_js(self: T) -> str: + def load_tool_js(self: DiagnosticToolBaseClass) -> str: """ Load the JavaScript for the diagnostic tool. @@ -75,7 +73,7 @@ def load_tool_js(self: T) -> str: tool_js = f.read() return tool_js - def show(self: T) -> None: + def show(self: DiagnosticToolBaseClass) -> None: """ Show the diagnostic tool in the notebook. @@ -97,7 +95,7 @@ def show(self: T) -> None: html = file_html(doc, resources=INLINE, template=self.html_template()) display(HTML(html)) - def html_template(self: T) -> str: + def html_template(self: DiagnosticToolBaseClass) -> str: """ HTML template object used to inject CSS styles for Bokeh Applications. @@ -145,11 +143,11 @@ def html_template(self: T) -> str: """ @abstractmethod - def create_document(self: T) -> Model: + def create_document(self: DiagnosticToolBaseClass) -> Model: """To be implemented by the inheriting class.""" ... - def _tool_json(self: T) -> Mapping[Any, Any]: + def _tool_json(self: DiagnosticToolBaseClass) -> Mapping[Any, Any]: """ Debugging method used primarily when creating a new diagnostic tool. diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py index 109b588ea0..27f1d1f8be 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Collection of serializers for the diagnostics tool use.""" - from typing import Dict, List from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples @@ -13,18 +12,40 @@ def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]: """ Convert Bean Machine models to a JSON serializable object. - Args: samples (MonteCarloSamples): Output of a model from Bean Machine. - Returns Dict[str, List[List[float]]]: The JSON serializable object for use in the diagnostics tools. """ - model = dict( - sorted( - {str(key): value.tolist() for key, value in samples.items()}.items(), - key=lambda item: item[0], - ), - ) + rv_identifiers = list(samples.keys()) + reshaped_data = {} + for rv_identifier in rv_identifiers: + rv_data = samples[rv_identifier] + rv_shape = rv_data.shape + num_rv_chains = rv_shape[0] + reshaped_data[f"{str(rv_identifier)}"] = [] + for rv_chain in range(num_rv_chains): + chain_data = rv_data[rv_chain, :] + chain_shape = chain_data.shape + if len(chain_shape) > 3 and 1 not in list(chain_shape): + msg = ( + "Unable to handle data with dimensionality larger than " "mxnxkx1." + ) + raise ValueError(msg) + elif len(chain_shape) == 3 and 1 in list(chain_shape): + if chain_shape[1] == 1 in list(chain_shape): + reshape_dimensions = chain_shape[2] + else: + reshape_dimensions = chain_shape[1] + for i, reshape_dimension in enumerate(range(reshape_dimensions)): + data = rv_data[rv_chain, :, reshape_dimension].reshape(-1) + if f"{str(rv_identifier)}[{i}]" not in reshaped_data: + reshaped_data[f"{str(rv_identifier)}[{i}]"] = [] + reshaped_data[f"{str(rv_identifier)}[{i}]"].append(data.tolist()) + elif len(chain_shape) == 1: + reshaped_data[f"{str(rv_identifier)}"].append( + rv_data[rv_chain, :].tolist(), + ) + model = dict(sorted(reshaped_data.items(), key=lambda item: item[0])) return model diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py b/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py index 82b48f2a1f..c116ef778a 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/plotting_utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. """Plotting utilities for the diagnostic tools.""" - from typing import List from bokeh.core.property.nullable import Nullable diff --git a/src/beanmachine/ppl/diagnostics/tools/viz.py b/src/beanmachine/ppl/diagnostics/tools/viz.py index eb3e627c55..c149ba9d7e 100644 --- a/src/beanmachine/ppl/diagnostics/tools/viz.py +++ b/src/beanmachine/ppl/diagnostics/tools/viz.py @@ -49,7 +49,24 @@ def __init__(self: DiagnosticsTools, mcs: MonteCarloSamples) -> None: @_requires_dev_packages def marginal1d(self: DiagnosticsTools) -> None: - """Marginal 1D tool.""" + """ + Marginal 1D diagnostic tool for a Bean Machine model. + + Returns: + None: Displays the tool directly in a Jupyter notebook. + """ from beanmachine.ppl.diagnostics.tools.marginal1d.tool import Marginal1d Marginal1d(self.mcs).show() + + @_requires_dev_packages + def trace(self: DiagnosticsTools) -> None: + """ + Trace diagnostic tool for a Bean Machine model. + + Returns: + None: Displays the tool directly in a Jupyter notebook. + """ + from beanmachine.ppl.diagnostics.tools.trace.tool import Trace + + Trace(self.mcs).show()