Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraceEnum_ELBO: Subsample local variables that depend on a global model-enumerated variable #1572

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
from functools import partial
from functools import partial, reduce
from operator import itemgetter
import warnings

Expand Down Expand Up @@ -959,33 +959,40 @@ def single_particle_elbo(rng_key):
*(frozenset(f.inputs) & group_plates for f in group_factors)
)
elim_plates = group_plates - outermost_plates
plate_to_scale = {}
for name in group_names:
for plate, value in (
model_trace[name].get("plate_to_scale", {}).items()
):
if plate in plate_to_scale:
if value != plate_to_scale[plate]:
raise ValueError(
"Expected all enumerated sample sites to share a common scale factor, "
f"but found different scales at plate('{plate}')."
)
else:
plate_to_scale[plate] = value
with funsor.interpretations.normalize:
cost = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
group_factors,
plates=group_plates,
eliminate=group_sum_vars | elim_plates,
plate_to_scale=plate_to_scale,
)
# TODO: add memoization
cost = funsor.optimizer.apply_optimizer(cost)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
scales_set = set()
for name in group_names | group_sum_vars:
site_scale = model_trace[name]["scale"]
if site_scale is None:
site_scale = 1.0
if isinstance(site_scale, jnp.ndarray):
raise ValueError(
"Enumeration only supports scalar handlers.scale"
)
scales_set.add(float(site_scale))
if len(scales_set) != 1:
raise ValueError(
"Expected all enumerated sample sites to share a common scale, "
f"but found {len(scales_set)} different scales."
)
scale = next(iter(scales_set))
scale = reduce(
funsor.ops.mul,
[
value
for plate, value in plate_to_scale.items()
if plate not in elim_plates
],
1.0,
)
# combine deps
deps = frozenset().union(
*[model_deps[name] for name in group_names]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"dev": [
"dm-haiku",
"flax",
"funsor>=0.4.1",
"funsor>=0.4.6",
"graphviz",
"jaxns>=2.0.1",
"matplotlib",
Expand Down
32 changes: 10 additions & 22 deletions test/contrib/test_enum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2314,14 +2314,10 @@ def actual_loss_fn(params_raw):
}
return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

with pytest.raises(
ValueError, match="Expected all enumerated sample sites to share a common scale"
):
# This never gets run because we don't support this yet.
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)

assert_equal(actual_loss, expected_loss, prec=1e-5)
assert_equal(actual_grads, expected_grads, prec=1e-5)
assert_equal(actual_loss, expected_loss, prec=1e-5)
assert_equal(actual_grads, expected_grads, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
Expand Down Expand Up @@ -2389,20 +2385,16 @@ def actual_loss_fn(params_raw):
}
return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

with pytest.raises(
ValueError, match="Expected all enumerated sample sites to share a common scale"
):
# This never gets run because we don't support this yet.
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)

assert_equal(actual_loss, expected_loss, prec=1e-5)
assert_equal(actual_grads, expected_grads, prec=1e-5)
assert_equal(actual_loss, expected_loss, prec=1e-5)
assert_equal(actual_grads, expected_grads, prec=1e-5)


@pytest.mark.parametrize("scale", [1, 10])
def test_model_enum_subsample_3(scale):
# Enumerate: a
# Subsample: a, b, c
# Subsample: b, c
# [ a - [----> b ]
# [ \ [ ]
# [ - [- [-> c ] ]
Expand Down Expand Up @@ -2464,14 +2456,10 @@ def actual_loss_fn(params_raw):
}
return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)

with pytest.raises(
ValueError, match="Expected all enumerated sample sites to share a common scale"
):
# This never gets run because we don't support this yet.
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)

assert_equal(actual_loss, expected_loss, prec=1e-3)
assert_equal(actual_grads, expected_grads, prec=1e-5)
assert_equal(actual_loss, expected_loss, prec=1e-3)
assert_equal(actual_grads, expected_grads, prec=1e-5)


def test_guide_plate_contraction():
Expand Down