Allow returning deterministic sites from guide in Predictive #3361

Status: Open · wants to merge 4 commits into the dev branch.
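For orientation before the diff, here is a minimal usage sketch of the feature this PR proposes (the toy model, guide, and the `x_sq` site name are illustrative, not taken from the changeset): with `return_deterministic_guide_sites=True`, sites recorded via `pyro.deterministic` inside the guide are included in the samples returned by `Predictive`, in addition to the sites resampled from the model.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def model(y=None):
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("obs", dist.Normal(x, 0.1), obs=y)

def guide(y=None):
    loc = pyro.param("loc", torch.tensor(0.0))
    x = pyro.sample("x", dist.Normal(loc, 0.2))
    # a deterministic site that exists only in the guide trace
    pyro.deterministic("x_sq", x**2)

predictive = Predictive(
    model,
    guide=guide,
    num_samples=100,
    return_deterministic_guide_sites=True,  # new flag proposed by this PR
)
samples = predictive(torch.tensor(4.0))
print(sorted(samples.keys()))  # expected to include "x_sq" alongside "obs" and "x"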
46 changes: 42 additions & 4 deletions pyro/infer/predictive.py
@@ -60,7 +60,14 @@ def _predictive_sequential(
)
collected_trace.append(trace)
collected_samples.append(
{site: trace.nodes[site]["value"] for site in return_site_shapes}
{
site: (
trace.nodes[site]["value"]
if site in trace.nodes
else samples[i][site]
)
for site in return_site_shapes
}
)

return _predictiveResults(
@@ -84,6 +91,7 @@ def _predictive(
model_args=(),
model_kwargs={},
mask=True,
posterior_deterministic_sites=(),
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
@@ -122,6 +130,9 @@ def _predictive(
elif site not in posterior_samples:
return_site_shapes[site] = site_shape

for site in posterior_deterministic_sites:
return_site_shapes[site] = posterior_samples[site].shape

# handle _RETURN site
if return_sites is not None and "_RETURN" in return_sites:
value = model_trace.nodes["_RETURN"]["value"]
@@ -143,7 +154,10 @@
).get_trace(*model_args, **model_kwargs)
predictions = {}
for site, shape in return_site_shapes.items():
value = trace.nodes[site]["value"]
if site in trace.nodes:
value = trace.nodes[site]["value"]
else:
value = reshaped_samples[site]
if site == "_RETURN" and shape is None:
predictions[site] = value
continue
@@ -179,6 +193,8 @@ class Predictive(torch.nn.Module):
:param bool parallel: predict in parallel by wrapping the existing model
in an outermost `plate` messenger. Note that this requires that the model has
all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.
:param bool return_deterministic_guide_sites: if True, include deterministic sites
from the guide in the returned samples. This does not affect the returned trace.
Default is `False`.
"""

def __init__(
@@ -189,6 +205,7 @@ def __init__(
num_samples=None,
return_sites=(),
parallel=False,
return_deterministic_guide_sites=False,
):
super().__init__()
if posterior_samples is None:
@@ -231,6 +248,7 @@ def __init__(
self.guide = guide
self.return_sites = return_sites
self.parallel = parallel
self.return_deterministic_guide_sites = return_deterministic_guide_sites

def call(self, *args, **kwargs):
"""
@@ -262,18 +280,37 @@ def forward(self, *args, **kwargs):
"""
posterior_samples = self.posterior_samples
return_sites = self.return_sites

guide_deterministic_sites = ()

if self.guide is not None:
# return all sites by default if a guide is provided.
return_sites = None if not return_sites else return_sites
posterior_samples = _predictive(
guide_pred_res = _predictive(
self.guide,
posterior_samples,
self.num_samples,
return_sites=None,
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
).samples
)
posterior_samples = guide_pred_res.samples

if self.return_deterministic_guide_sites:
if isinstance(guide_pred_res.trace, Trace):
guide_tr = guide_pred_res.trace
else:
guide_tr = guide_pred_res.trace[0]

guide_deterministic_sites = tuple(
name
for name, site in guide_tr.nodes.items()
if site["type"] == "sample"
if site["infer"].get("_deterministic")
if (return_sites is None or name in return_sites)
)

return _predictive(
self.model,
posterior_samples,
@@ -282,6 +319,7 @@
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
posterior_deterministic_sites=guide_deterministic_sites,
).samples

def get_samples(self, *args, **kwargs):
8 changes: 7 additions & 1 deletion pyro/util.py
@@ -266,6 +266,12 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
if site["type"] == "sample"
if site["infer"].get("is_auxiliary")
)
det_vars = set(
name
for name, site in guide_trace.nodes.items()
if site["type"] == "sample"
if site["infer"].get("_deterministic")
)
model_vars = set(
name
for name, site in model_trace.nodes.items()
@@ -284,7 +290,7 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
warnings.warn(
"Found auxiliary vars in the model: {}".format(aux_vars & model_vars)
)
if not (guide_vars <= model_vars | aux_vars):
if not (guide_vars <= model_vars | aux_vars | det_vars):
[Review thread on the changed line above]

Member Author: @fritzo, it looks like deterministic sites are currently treated as (auxiliary) sample sites in the guide when comparing the model to the guide. Do you know if this is the intended behavior? I assumed it wasn't and sketched the following solution.

Member: Hi @OlaRonning, I'm unsure as I seldom use predictive, preferring custom predictive utilities. @fehiepsi, do you have an opinion? It would be nice to keep similar behavior in pyro and numpyro.

warnings.warn(
"Found non-auxiliary vars in guide but not model, "
"consider marking these infer={{'is_auxiliary': True}}:\n{}".format(
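As background for the `det_vars` filter added to `check_model_guide_match` above and for the site filtering in `Predictive.forward`: a `pyro.deterministic` statement is recorded in the trace as an observed, masked `Delta` sample site whose `infer` dict carries `_deterministic: True`. The following minimal sketch (the toy guide and the `x_sq` site name are illustrative, not part of this PR) inspects that flag:

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def toy_guide():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    # recorded as an observed, masked Delta site in the trace
    pyro.deterministic("x_sq", x**2)

trace = poutine.trace(toy_guide).get_trace()
site = trace.nodes["x_sq"]
print(site["type"])                         # "sample"
print(site["infer"].get("_deterministic"))  # True

Because such sites never appear in the model trace, the model/guide comparison has to whitelist them explicitly, which is what the `det_vars` union in the changed condition does.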
63 changes: 63 additions & 0 deletions tests/infer/test_predictive.py
@@ -269,6 +269,69 @@ def model(y=None):
assert_close(actual["x3"].mean(), y, rtol=0.1)


@pytest.mark.parametrize("with_plate", [True, False])
@pytest.mark.parametrize("event_shape", [(), (2,)])
@pytest.mark.parametrize("return_deterministic_guide_sites", [True, False])
@pytest.mark.parametrize("return_sites", [[], ["x4"]])
@pytest.mark.parametrize("use_determinisitic_guide", [True, False])
def test_deterministic_guide_return(
with_plate,
event_shape,
return_deterministic_guide_sites,
return_sites,
use_deterministic_guide,
):
def model(y=None):
with pyro.util.optional(pyro.plate("plate", 3), with_plate):
x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
x2 = pyro.deterministic("x2", x**2, event_dim=len(event_shape))

pyro.deterministic("x3", x2)
return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y)

def deterministic_guide(y=None):
with pyro.util.optional(pyro.plate("plate", 3), with_plate):
x = pyro.sample("x", dist.Normal(0, 2).expand(event_shape).to_event())
x4 = pyro.deterministic("x4", x**2, event_dim=len(event_shape))

pyro.deterministic("x5", x4)

def non_deterministic_guide(y=None):
with pyro.util.optional(pyro.plate("plate", 3), with_plate):
pyro.sample("x", dist.Normal(0, 2).expand(event_shape).to_event())

if use_deterministic_guide:
guide = deterministic_guide
else:
guide = non_deterministic_guide

y = torch.tensor(4.0)
svi = SVI(model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
for i in range(100):
svi.step(y)

actual = Predictive(
model,
guide=guide,
num_samples=1000,
return_sites=return_sites,
return_deterministic_guide_sites=return_deterministic_guide_sites,
)()

if return_deterministic_guide_sites and use_determinisitic_guide:
assert "x4" in actual
assert_close(actual["x4"].mean(), y, rtol=0.1)
# When return_sites is empty, include all deterministic guide sites
if len(return_sites) == 0:
assert "x5" in actual
assert_close(actual["x5"].mean(), y, rtol=0.1)
else:
assert "x5" not in actual
else:
assert "x4" not in actual
assert "x5" not in actual


def test_get_mask_optimization():
def model():
x = pyro.sample("x", dist.Normal(0, 1))