-
Notifications
You must be signed in to change notification settings - Fork 245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expose some mcmc states for sequential sampling strategy #861
Conversation
Thanks for your suggestions, Martin! I will need to discuss with Neeraj about the API. Especially, I am not sure if #781 is related and can be resolved here. |
numpyro/infer/mcmc.py
Outdated
def init_state(self): | ||
""" | ||
The initial state of the MCMC chain. If this attribute is None, | ||
:meth:`run` will call `self.sampler.init(...)` method for initialization. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My concern with exposing this is that it might give the impression that users can use this to specify initial value of parameters or mass matrix (which isn't true). However, there are other auxiliary variables like z_grad
and those mentioned in the warning that are non-obvious and would result in leakage of internal details. Could you elaborate on the use case for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that this warning is tricky to deliver to users. The motivation here is to start the chain on new data with the previously adapted mass matrix, step_size, and the last sample, as requested by @martinjankowiak and in #534.
For those who want such a warm start feature, they can still achieve that by modifying the warmup state by state._replace(i=0, pe=1e6)
and taking care of other related diagnostics information (wrong z_grad
is fine because z_grad
on the old data only drives momentum on the first leapfrog step - but a wrong pe might lead to rejecting all proposals). So I think we should not expose it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The motivation here is to start the chain on new data with the previously adapted mass matrix, step_size, and the last sample
Do you mean starting from these values and running adaptation from there? IIUC there are three things that we can specify for initialization - initial values (which is possible using init_to_value
?), step size and mass matrix. This method is to be able to expose the latter two. Is that correct? Even then, I suppose the adaptation will re-learn fresh values for the latter two.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it will re-learn them but started from the previous values. This might be helpful for tricky posteriors, which require spending much time on initial steps...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm..how about things like the adaptation window index, wouldn't that need to be reset?
Another choice might be to have something like MCMCKernel.get_init_state(self, *args, **kwargs)
method. For HMC, it could be get_init_state(self, init_params, mass_matrix=None, step_size=None)
which can internally do all this book-keeping (e.g. setting i to 0). That might be a bit more verbose but easier to explain since we can tell the users to use the get_init_state
method of the corresponding kernel without other caveats.
kernel = HMC(...)
mcmc = MCMC(kernel, ...)
mcmc.initial_state = kernel.get_init_state(init_params, mass_matrix)
mcmc.run(...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense to me. We also need args, kwargs so it is more like kernel.init
method (except that we won't try to find valid initial params). We have discussed how to expose the inverse mass matrix to the API in #536 but found that it is a bit ambiguous. I think we should skip this feature unless there are more requests for it. What do you think?
how about things like the adaptation window index
We use hmc_state.i for the update so it should be fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, I think it is fine to discuss a bit more before finalizing this. Power users can already do this if they need to, but we should provide an intuitive interface if we are going to advertise this.
numpyro/infer/mcmc.py
Outdated
self._init_state = state | ||
|
||
@property | ||
def warmup_state(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are to expose this, what do you think about calling this post_warmup_state
for clarity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by post
? Does it mean posterior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant post as in after
as opposed to pre. e.g. post-warmup draws
from pystan. Other suggestions are welcome too, I just think warmup_state
is fine for internal usage but isn't descriptive enough to be exposed in the API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no preference so I'll make the change. Thanks for your suggestion!
To resolve this, I think we'll need to separately collect warmup samples and have a |
I think I am better at understanding that issue now. It is not about how to collect those warmup samples, but about having a more intuitive API to do that job. We can do it later if it is needed. :) Your suggestions there make sense to me. |
Resolves #534. This is also requested by @rexdouglass in #539
This is just a solution, mainly for discussion. I'm not sure if this is a good API so if reviewers have other ideas, please let me know.