Skip to content

Commit

Permalink
Add batched calculation option to energy_score_empirical in order t…
Browse files Browse the repository at this point in the history
…o reduce memory consumption (#3402)

* Add batched calculation option to energy_score_empirical in order to reduce memory consumption.

* Replace native Python sum with torch stack(...).sum().

---------

Co-authored-by: Ben Zickel <[email protected]>
  • Loading branch information
BenZickel and Ben Zickel authored Sep 26, 2024
1 parent 0d3243a commit 04c371f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
50 changes: 44 additions & 6 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import math
import numbers
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from torch.fft import irfft, rfft
Expand Down Expand Up @@ -510,7 +510,9 @@ def crps_empirical(pred, truth):
return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2


def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
def energy_score_empirical(
pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None
) -> torch.Tensor:
r"""
Computes negative Energy Score ES* (see equation 22 in [1]) between a
set of multivariate samples ``pred`` and a true data vector ``truth``. Running time
Expand Down Expand Up @@ -538,6 +540,8 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
The leftmost dim is that of the multivariate sample.
:param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except
for the second leftmost dim which can have any value or be omitted.
:param int pred_batch_size: If specified the predictions will be batched before calculation
according to the specified batch size in order to reduce memory consumption.
:return: A tensor of shape ``truth.shape``.
:rtype: torch.Tensor
"""
Expand All @@ -552,10 +556,44 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten
"Actual shapes: {} versus {}".format(pred.shape, truth.shape)
)

retval = (
torch.cdist(pred, truth).mean(dim=-2)
- 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None]
)
if pred_batch_size is None:
retval = (
torch.cdist(pred, truth).mean(dim=-2)
- 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None]
)
else:
# Divide predictions into batches
pred_len = pred.shape[-2]
pred_batches = []
while pred.numel() > 0:
pred_batches.append(pred[..., :pred_batch_size, :])
pred = pred[..., pred_batch_size:, :]
# Calculate predictions distance to truth
retval = (
torch.stack(
[
torch.cdist(pred_batch, truth).sum(dim=-2)
for pred_batch in pred_batches
],
dim=0,
).sum(dim=0)
/ pred_len
)
# Calculate predictions self distance
for aux_pred_batch in pred_batches:
retval = (
retval
- 0.5
* torch.stack(
[
torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2])
for pred_batch in pred_batches
],
dim=0,
).sum(dim=0)[..., None]
/ pred_len
/ pred_len
)

if remove_leftmost_dim:
retval = retval[..., 0]
Expand Down
23 changes: 23 additions & 0 deletions tests/ops/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,26 @@ def test_multivariate_energy_score(sample_dim, num_samples=10000):
rtol=0.02,
)
assert energy_score * 1.02 < energy_score_uncorrelated


@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)])
@pytest.mark.parametrize("sample_dim", [30, 100])
@pytest.mark.parametrize(
"num_samples, pred_batch_size", [(100, 10), (100, 30), (100, 100), (100, 200)]
)
def test_energy_score_empirical_batched_calculation(
batch_shape, sample_dim, num_samples, pred_batch_size
):
# Generate data
truth = torch.randn(batch_shape + (sample_dim,))
pred = torch.randn(batch_shape + (num_samples, sample_dim))
# Do batched and regular calculation
expected = energy_score_empirical(pred, truth)
actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size)
# Check accuracy
assert_close(actual, expected)


def test_jit_compilation():
# Test that functions can be JIT compiled
torch.jit.script(energy_score_empirical)

0 comments on commit 04c371f

Please sign in to comment.