Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Equalize effect handler #3375

Merged
merged 4 commits into from
Jul 10, 2024
Merged

Equalize effect handler #3375

merged 4 commits into from
Jul 10, 2024

Conversation

BenZickel
Copy link
Contributor

@BenZickel BenZickel commented Jun 16, 2024

Given a stochastic function with some primitive statements and a list of names, the equalize effect handler forces the primitive statements at those names to have the same value, with that value being the result of the first primitive statement matching those names.

Consider the following Pyro program:

def per_category_model(category):
    shift = pyro.param(f'{category}_shift', torch.randn(1))
    mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
    std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
    return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))

Running the program for multiple categories can be done by

def model(categories):
    return {category:per_category_model(category) for category in categories}

To make the std sample sites have the same value, we can write

equal_std_model = pyro.poutine.equalize(model, '.+_std')

If on top of the above we would like to make the ‘shift’ parameters identical, we can write

equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clean implementation.

From a software perspective this seems sensible, but from a modeling perspective I'm a bit worried that priors are double counted. I guess it's fine to trust the user.

@fritzo fritzo merged commit daea9a6 into pyro-ppl:dev Jul 10, 2024
9 checks passed
@BenZickel
Copy link
Contributor Author

Hi @fritzo, regarding your comment on double counting of priors, I think this is not happening, as only the first sample site is sampled from its distribution, while the remaining sites are converted to deterministic sample sites. Hopefully, this is what you meant in your comment.

@fritzo
Copy link
Member

fritzo commented Jul 10, 2024

Ah I see, so the prior for each equalized site will be the prior from first matching site. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants