Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Revert the hacky change for N-d tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
horizon-blue committed Oct 18, 2022
1 parent 51ccc88 commit 967f79e
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
Expand All @@ -13,27 +14,16 @@
def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]:
"""
Convert Bean Machine models to a JSON serializable object.
Args:
samples (MonteCarloSamples): Output of a model from Bean Machine.
Returns
Dict[str, List[List[float]]]: The JSON serializable object for use in the
diagnostics tools.
"""
model = dict(
sorted(
{str(key): value for key, value in samples.items()}.items(),
{str(key): value.tolist() for key, value in samples.items()}.items(),
key=lambda item: item[0],
),
)
retval = {}
for node in model:
if model[node].ndim > 2:
# e.g. tensor of shape (4, 100, 3, 2) will generate a slice that's
# equivalent to [:, :, 0, 0]
slicer = (slice(None), slice(None)) + (0,) * (model[node].ndim - 2)
model[node] = model[node][slicer]
retval[node] = model[node].tolist()

return retval
return model

0 comments on commit 967f79e

Please sign in to comment.