From 4688dd0a480697e4b7eaf1df2453390f28034c2d Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Tue, 24 Sep 2024 20:07:24 +0300 Subject: [PATCH 1/2] Add batched calculation option to energy_score_empirical in order to reduce memory consumption. --- pyro/ops/stats.py | 49 ++++++++++++++++++++++++++++++++++++----- tests/ops/test_stats.py | 18 +++++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index a0a546059a..fcb102655e 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -3,7 +3,7 @@ import math import numbers -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch.fft import irfft, rfft @@ -510,7 +510,9 @@ def crps_empirical(pred, truth): return (pred - truth).abs().mean(0) - (diff * weight).sum(0) / num_samples**2 -def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor: +def energy_score_empirical( + pred: torch.Tensor, truth: torch.Tensor, pred_batch_size: Optional[int] = None +) -> torch.Tensor: r""" Computes negative Energy Score ES* (see equation 22 in [1]) between a set of multivariate samples ``pred`` and a true data vector ``truth``. Running time @@ -538,6 +540,8 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten The leftmost dim is that of the multivariate sample. :param torch.Tensor truth: A tensor of true observations with same shape as ``pred`` except for the second leftmost dim which can have any value or be omitted. + :param int pred_batch_size: If specified the predictions will be batched before calculation + according to the specified batch size in order to reduce memory consumption. :return: A tensor of shape ``truth.shape``. :rtype: torch.Tensor """ @@ -552,10 +556,43 @@ def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Ten "Actual shapes: {} versus {}".format(pred.shape, truth.shape) ) - retval = ( - torch.cdist(pred, truth).mean(dim=-2) - - 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None] - ) + if pred_batch_size is None: + retval = ( + torch.cdist(pred, truth).mean(dim=-2) + - 0.5 * torch.cdist(pred, pred).mean(dim=[-1, -2])[..., None] + ) + else: + # Divide predictions into batches + pred_len = pred.shape[-2] + pred_batches = [] + while pred.numel() > 0: + pred_batches.append(pred[..., :pred_batch_size, :]) + pred = pred[..., pred_batch_size:, :] + # Calculate predictions distance to truth + retval = ( + torch.cat( + [ + torch.cdist(pred_batch, truth).sum(dim=-2, keepdim=True) + for pred_batch in pred_batches + ], + dim=-2, + ).sum(dim=-2) + / pred_len + ) + # Calculate predictions self distance + for aux_pred_batch in pred_batches: + retval = ( + retval + - 0.5 + * sum( # type: ignore[index] + [ + torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2]) + for pred_batch in pred_batches + ] + )[..., None] + / pred_len + / pred_len + ) if remove_leftmost_dim: retval = retval[..., 0] diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index 41f7ba3c8c..d52ea0f6f2 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -355,3 +355,21 @@ def test_multivariate_energy_score(sample_dim, num_samples=10000): rtol=0.02, ) assert energy_score * 1.02 < energy_score_uncorrelated + + +@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)]) +@pytest.mark.parametrize("sample_dim", [30, 100]) +@pytest.mark.parametrize( + "num_samples, pred_batch_size", [(100, 10), (100, 30), (100, 100), (100, 200)] +) +def test_energy_score_empirical_batched_calculation( + batch_shape, sample_dim, num_samples, pred_batch_size +): + # Generate data + truth = torch.randn(batch_shape + (sample_dim,)) + pred = torch.randn(batch_shape + (num_samples, sample_dim)) + # Do batched and regular calculation + expected = energy_score_empirical(pred, truth) + actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size) + # Check accuracy + assert_close(actual, expected) From 979bafafed3915e0688516f14f4f35d4cc482ed9 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Thu, 26 Sep 2024 09:40:55 +0300 Subject: [PATCH 2/2] Replace native Python sum with torch stack(...).sum(). --- pyro/ops/stats.py | 15 ++++++++------- tests/ops/test_stats.py | 5 +++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index fcb102655e..73f6054e5a 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -570,13 +570,13 @@ def energy_score_empirical( pred = pred[..., pred_batch_size:, :] # Calculate predictions distance to truth retval = ( - torch.cat( + torch.stack( [ - torch.cdist(pred_batch, truth).sum(dim=-2, keepdim=True) + torch.cdist(pred_batch, truth).sum(dim=-2) for pred_batch in pred_batches ], - dim=-2, - ).sum(dim=-2) + dim=0, + ).sum(dim=0) / pred_len ) # Calculate predictions self distance @@ -584,12 +584,13 @@ def energy_score_empirical( retval = ( retval - 0.5 - * sum( # type: ignore[index] + * torch.stack( [ torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2]) for pred_batch in pred_batches - ] - )[..., None] + ], + dim=0, + ).sum(dim=0)[..., None] / pred_len / pred_len ) diff --git a/tests/ops/test_stats.py b/tests/ops/test_stats.py index d52ea0f6f2..cae8ef5aba 100644 --- a/tests/ops/test_stats.py +++ b/tests/ops/test_stats.py @@ -373,3 +373,8 @@ def test_energy_score_empirical_batched_calculation( actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size) # Check accuracy assert_close(actual, expected) + + +def test_jit_compilation(): + # Test that functions can be JIT compiled + torch.jit.script(energy_score_empirical)