Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Fixing dtype mismatch in HMC/NUTS mass matrices (#1789)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1789

Previously, the mass matrices and some of the NUTS hyperparameters were not instantiated to match the dtype and device of the input tensors, which can be problematic if the input tensors have dtype double or if they are on a GPU.

Reviewed By: feynmanliang

Differential Revision: D40788741

fbshipit-source-id: 56e7918b8c4e3279c70674066242b673725fcbab
  • Loading branch information
horizon-blue authored and facebook-github-bot committed Oct 29, 2022
1 parent fc06b00 commit d222e69
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
4 changes: 1 addition & 3 deletions src/beanmachine/ppl/inference/proposer/hmc_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def __init__(
self.adapt_step_size = adapt_step_size
self.adapt_mass_matrix = adapt_mass_matrix
# we need mass matrix adapter to sample momentums
self._mass_matrix_adapter = MassMatrixAdapter(
len(self._positions), full_mass_matrix
)
self._mass_matrix_adapter = MassMatrixAdapter(self._positions, full_mass_matrix)
if self.adapt_step_size:
self.step_size = self._find_reasonable_step_size(
torch.as_tensor(initial_step_size),
Expand Down
4 changes: 2 additions & 2 deletions src/beanmachine/ppl/inference/proposer/hmc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ class MassMatrixAdapter:
https://mc-stan.org/docs/2_26/reference-manual/hmc-algorithm-parameters.html#euclidean-metric
"""

def __init__(self, matrix_size: int, full_mass_matrix: bool = False):
def __init__(self, initial_positions: torch.Tensor, full_mass_matrix: bool = False):
# inverse mass matrices, aka the inverse "metric"
self.mass_inv = torch.ones(matrix_size)
self.mass_inv = torch.ones_like(initial_positions)
# distribution objects for generating momentums
self.momentum_dist: dist.Distribution = dist.Normal(0.0, self.mass_inv)
if full_mass_matrix:
Expand Down
7 changes: 4 additions & 3 deletions src/beanmachine/ppl/inference/proposer/nuts_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,17 +278,18 @@ def propose(self, world: World) -> Tuple[World, torch.Tensor]:
log_slice = -current_energy
else:
# this is a more stable way to sample from log(Uniform(0, exp(-current_energy)))
log_slice = torch.log1p(-torch.rand(())) - current_energy
log_slice = torch.log1p(-torch.rand_like(current_energy)) - current_energy
tree_node = _TreeNode(self._positions, momentums, self._pe_grad)
tree = _Tree(
left=tree_node,
right=tree_node,
proposal=self._positions,
pe=self._pe,
pe_grad=self._pe_grad,
log_weight=torch.tensor(0.0), # log accept prob of staying at current state
# log accept prob of staying at current state
log_weight=torch.zeros_like(log_slice),
sum_momentums=momentums,
sum_accept_prob=torch.tensor(0.0),
sum_accept_prob=torch.zeros_like(log_slice),
num_proposals=torch.tensor(0),
turned_or_diverged=torch.tensor(False),
)
Expand Down
3 changes: 2 additions & 1 deletion tests/ppl/inference/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def init_to_zero(d: dist.Distribution):
@pytest.mark.parametrize(
"algorithm",
[
bm.GlobalNoUTurnSampler(),
bm.GlobalNoUTurnSampler(full_mass_matrix=False),
bm.GlobalNoUTurnSampler(full_mass_matrix=True),
bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0),
bm.SingleSiteAncestralMetropolisHastings(),
bm.SingleSiteNewtonianMonteCarlo(),
Expand Down
2 changes: 1 addition & 1 deletion tests/ppl/inference/proposer/hmc_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_mass_matrix_adapter(full_mass_matrix):
positions_dict = RealSpaceTransform(world, world.latent_nodes)(dict(world))
dict2vec = DictToVecConverter(positions_dict)
positions = dict2vec.to_vec(positions_dict)
mass_matrix_adapter = MassMatrixAdapter(len(positions), full_mass_matrix)
mass_matrix_adapter = MassMatrixAdapter(positions, full_mass_matrix)
momentums = mass_matrix_adapter.initialize_momentums(positions)
assert isinstance(momentums, torch.Tensor)
assert momentums.shape == positions.shape
Expand Down

0 comments on commit d222e69

Please sign in to comment.