Potential extension to compute_log_probs to facilitate VI model diagnostics #1939

Open
hessammehr opened this issue Dec 19, 2024 · 0 comments
Labels
question Further information is requested

Comments

@hessammehr
Contributor

When diagnosing a VI fit, it is often useful to check, site by site, how well samples from the guide agree with the prior, with the data, and so on. So far I've been using a function like the following, but I'd like to know whether there are better-established alternatives and, if not, whether this would be appropriate as a backwards-compatible extension of the newly introduced (and very useful) `compute_log_probs` function (or perhaps as a separate function).

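```python
from typing import Optional

from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import compute_log_probs as clp


def compute_log_probs(
    model,
    model_args: tuple,
    model_kwargs: dict,
    model_params: dict,
    guide=None,
    guide_params: Optional[dict] = None,
    sum_log_prob: bool = True,
    rng_key=None,  # optional addition: PRNG key for sampling from the guide
):
    if guide is not None:
        # Guides such as AutoNormal sample their latent sites, which needs a PRNG key
        # unless the caller is already inside a `seed` handler.
        if rng_key is not None:
            guide = seed(guide, rng_key)
        # Draw latent values from the guide (called with the model's arguments).
        guide_trace = trace(substitute(guide, guide_params or {})).get_trace(*model_args, **model_kwargs)
        # Fix the model's latent sites to the values drawn from the guide.
        model = replay(model, guide_trace)
    # Without a guide this reduces to numpyro's own compute_log_probs.
    return clp(model, model_args, model_kwargs, model_params, sum_log_prob=sum_log_prob)
```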
@fehiepsi fehiepsi added the question Further information is requested label Dec 20, 2024