Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for deterministic dependent samples in PyroSample [enhancement] #3376

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class PyroSample:
assert isinstance(my_module, PyroModule)
my_module.x = PyroSample(Normal(0, 1)) # independent
my_module.y = PyroSample(lambda self: Normal(self.x, 1)) # dependent
my_module.z = PyroSample(lambda self: self.y ** 2) # deterministic dependent

or EXPERIMENTALLY as a decorator on lazy initialization methods::

Expand All @@ -175,16 +176,22 @@ def x(self):
def y(self):
return Normal(self.x, 1) # dependent

@PyroSample
def z(self):
return self.y ** 2 # deterministic dependent

def forward(self):
return self.y # accessed like a @property
return self.z # accessed like a @property

:param prior: distribution object or function that inputs the
:class:`PyroModule` instance ``self`` and returns a distribution
object.
object or a deterministic value.
"""

prior: Union[
"TorchDistributionMixin", Callable[["PyroModule"], "TorchDistributionMixin"]
"TorchDistributionMixin",
Callable[["PyroModule"], "TorchDistributionMixin"],
Callable[["PyroModule"], torch.Tensor],
]

def __post_init__(self) -> None:
Expand Down Expand Up @@ -605,13 +612,17 @@ def __getattr__(self, name: str) -> Any:
if value is None:
if not hasattr(prior, "sample"): # if not a distribution
prior = prior(self)
value = pyro.sample(fullname, prior)
value = (
pyro.deterministic(fullname, prior)
if isinstance(prior, torch.Tensor)
else pyro.sample(fullname, prior)
)
context.set(fullname, value)
return value
else: # Cannot determine supermodule and hence cannot compute fullname.
if not hasattr(prior, "sample"): # if not a distribution
prior = prior(self)
return prior()
return prior if isinstance(prior, torch.Tensor) else prior()

result = super().__getattr__(name)

Expand Down
24 changes: 21 additions & 3 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,10 @@ def __init__(self, size):
)
self.s = PyroSample(dist.Normal(0, 1))
self.t = PyroSample(lambda self: dist.Normal(self.s, self.z))
self.u = PyroSample(lambda self: self.t**2)

def forward(self):
return self.x + self.y + self.t
return self.x + self.y + self.u


class DecoratorModel(PyroModule):
Expand Down Expand Up @@ -521,8 +522,12 @@ def s(self):
def t(self):
return dist.Normal(self.s, self.z).to_event(1)

@PyroSample
def u(self):
return self.t**2

def forward(self):
return self.x + self.y + self.t
return self.x + self.y + self.u


@pytest.mark.parametrize("Model", [AttributeModel, DecoratorModel])
Expand All @@ -531,19 +536,32 @@ def test_decorator(Model, size):
model = Model(size)
for i in range(2):
trace = poutine.trace(model).get_trace()
assert set(trace.nodes.keys()) == {"_INPUT", "x", "y", "z", "s", "t", "_RETURN"}
assert set(trace.nodes.keys()) == {
"_INPUT",
"x",
"y",
"z",
"s",
"t",
"u",
"_RETURN",
}

assert trace.nodes["x"]["type"] == "param"
assert trace.nodes["y"]["type"] == "param"
assert trace.nodes["z"]["type"] == "param"
assert trace.nodes["s"]["type"] == "sample"
assert trace.nodes["t"]["type"] == "sample"
assert trace.nodes["u"]["type"] == "sample"

assert trace.nodes["x"]["value"].shape == (size,)
assert trace.nodes["y"]["value"].shape == (size,)
assert trace.nodes["z"]["value"].shape == (size,)
assert trace.nodes["s"]["value"].shape == ()
assert trace.nodes["t"]["value"].shape == (size,)
assert trace.nodes["u"]["value"].shape == (size,)

assert trace.nodes["u"]["infer"] == {"_deterministic": True}


def test_mixin_factory():
Expand Down
Loading