From 1d7da542d309fb0d227d92d74d703866a3e95039 Mon Sep 17 00:00:00 2001 From: Ben Zickel <35469979+BenZickel@users.noreply.github.com> Date: Thu, 27 Jun 2024 23:50:55 +0300 Subject: [PATCH] Fix pyro parameter nodes not showing up in graph of pyro deterministic nodes [bugfix] (#3378) * Fixes pyro parameter nodes not showing up in graph of pyro deterministic nodes. * Linting and formatting. * Ignore existing object redefinitions as they are now checked by ruff. --------- Co-authored-by: Ben Zickel --- pyro/distributions/hmm.py | 2 +- pyro/distributions/stable_log_prob.py | 2 +- pyro/infer/inspect.py | 2 +- tests/infer/test_util.py | 10 ++++++++++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pyro/distributions/hmm.py b/pyro/distributions/hmm.py index 9e5d714aa4..9f2a242682 100644 --- a/pyro/distributions/hmm.py +++ b/pyro/distributions/hmm.py @@ -1091,7 +1091,7 @@ def __init__( self.transforms = transforms @constraints.dependent_property(event_dim=2) - def support(self): + def support(self): # noqa: F811 return constraints.independent(self.observation_dist.support, 1) def expand(self, batch_shape, _instance=None): diff --git a/pyro/distributions/stable_log_prob.py b/pyro/distributions/stable_log_prob.py index c5c953c393..fa173e58f9 100644 --- a/pyro/distributions/stable_log_prob.py +++ b/pyro/distributions/stable_log_prob.py @@ -44,7 +44,7 @@ def set_integrator(num_points): # Stub which is replaced by the default integrator when called for the first time # if a default integrator has not already been set. -def integrate(*args, **kwargs): +def integrate(*args, **kwargs): # noqa: F811 set_integrator(num_points=501) return integrate(*args, **kwargs) diff --git a/pyro/infer/inspect.py b/pyro/infer/inspect.py index 88a722fd5d..2580301dba 100644 --- a/pyro/infer/inspect.py +++ b/pyro/infer/inspect.py @@ -331,7 +331,7 @@ def _get_type_from_frozenname(frozen_name): sample_param[name] = [ upstream - for upstream in get_provenance(site["fn"].log_prob(site["value"])) + for upstream in provenance if upstream != name and _get_type_from_frozenname(upstream) == "param" ] diff --git a/tests/infer/test_util.py b/tests/infer/test_util.py index 45f5807295..519fa6b251 100644 --- a/tests/infer/test_util.py +++ b/tests/infer/test_util.py @@ -72,3 +72,13 @@ def guide(zdim=1, scale=1.0): scale=scale, ) assert k > krange[0] and k < krange[1] + + +def test_render_model_deterministic_param(): + def model(): + value = pyro.param("param", torch.tensor(0.0)) + pyro.deterministic("deterministic", value) + + graph = pyro.render_model(model, render_params=True, render_deterministic=True) + + assert "\tparam -> deterministic\n" in graph.body