Skip to content

Commit

Permalink
Fix pyro parameter nodes not showing up in graph of pyro deterministi…
Browse files Browse the repository at this point in the history
…c nodes [bugfix] (#3378)

* Fixes pyro parameter nodes not showing up in graph of pyro deterministic nodes.

* Linting and formatting.

* Ignore existing object redefinitions as they are now checked by ruff.

---------

Co-authored-by: Ben Zickel <[email protected]>
  • Loading branch information
BenZickel and Ben Zickel authored Jun 27, 2024
1 parent 64e71ee commit 1d7da54
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def __init__(
self.transforms = transforms

@constraints.dependent_property(event_dim=2)
def support(self):
def support(self): # noqa: F811
return constraints.independent(self.observation_dist.support, 1)

def expand(self, batch_shape, _instance=None):
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/stable_log_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def set_integrator(num_points):

# Stub which is replaced by the default integrator when called for the first time
# if a default integrator has not already been set.
def integrate(*args, **kwargs):
def integrate(*args, **kwargs): # noqa: F811
set_integrator(num_points=501)
return integrate(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _get_type_from_frozenname(frozen_name):

sample_param[name] = [
upstream
for upstream in get_provenance(site["fn"].log_prob(site["value"]))
for upstream in provenance
if upstream != name and _get_type_from_frozenname(upstream) == "param"
]

Expand Down
10 changes: 10 additions & 0 deletions tests/infer/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ def guide(zdim=1, scale=1.0):
scale=scale,
)
assert k > krange[0] and k < krange[1]


def test_render_model_deterministic_param():
def model():
value = pyro.param("param", torch.tensor(0.0))
pyro.deterministic("deterministic", value)

graph = pyro.render_model(model, render_params=True, render_deterministic=True)

assert "\tparam -> deterministic\n" in graph.body

0 comments on commit 1d7da54

Please sign in to comment.