diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py index 760a5417f..8c84f99c6 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py @@ -13,27 +13,16 @@ def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]: """ Convert Bean Machine models to a JSON serializable object. - Args: samples (MonteCarloSamples): Output of a model from Bean Machine. - Returns Dict[str, List[List[float]]]: The JSON serializable object for use in the diagnostics tools. """ model = dict( sorted( - {str(key): value for key, value in samples.items()}.items(), + {str(key): value.tolist() for key, value in samples.items()}.items(), key=lambda item: item[0], ), ) - retval = {} - for node in model: - if model[node].ndim > 2: - # e.g. tensor of shape (4, 100, 3, 2) will generate a slice that's - # equivalent to [:, :, 0, 0] - slicer = (slice(None), slice(None)) + (0,) * (model[node].ndim - 2) - model[node] = model[node][slicer] - retval[node] = model[node].tolist() - - return retval + return model