Add Stable distribution with numerically integrated log-probability calculation (StableWithLogProb). #3369
Conversation
…tion of the Stable distribution.
…n order to improve convergence.
…the first time in order to avoid requiring scipy when building the docs.
It's great to see this implemented!
I don't trust my review of the math, but I would trust some sort of density test. One option is to use goftests.density_goodness_of_fit, something like:
import goftests
import pytest
from pyro.distributions import StableWithLogProb

@pytest.mark.parametrize(...)
def test_density(stability, skew, loc, scale):
    # Pyro's Stable constructor is (stability, skew, scale, loc),
    # so pass the last two by keyword to avoid swapping them.
    d = StableWithLogProb(stability, skew, scale=scale, loc=loc)
    samples = d.sample((1000,))  # sample_shape must be a tuple, not an int
    probs = d.log_prob(samples).exp()
    gof = goftests.density_goodness_of_fit(samples, probs)
    assert gof > 1e-2
Another option is to check against a reference implementation, say something in scipy. WDYT?
(btw thanks for your patience!)
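For the reference-implementation option mentioned above, a hedged sketch of what a scipy comparison could look like (my illustration, not this PR's test: scipy's levy_stable defaults to the S1 parameterization while Pyro's Stable defaults to S0, so skew is set to zero where the two coincide, and the tolerance is a guess):

import numpy as np
import torch
from scipy.stats import levy_stable
from pyro.distributions import StableWithLogProb

# skew=0 makes scipy's default S1 parameterization agree with Pyro's S0 coords.
d = StableWithLogProb(stability=1.5, skew=0.0)
x = torch.linspace(-5.0, 5.0, 11)
expected = levy_stable.logpdf(x.numpy(), 1.5, 0.0)
assert np.allclose(d.log_prob(x).numpy(), expected, atol=1e-3)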
beta = self.skew.double()
value = value.double()

return _stable_log_prob(alpha, beta, value, self.coords) - self.scale.log()
I think we'll want to convert the result of _stable_log_prob() back to value.dtype, right? Something like:
logp = _stable_log_prob(alpha, beta, value, self.coords)
return logp.to(dtype=value.dtype) - self.scale.log()
Looks great!
I think it would actually be cleaner to implement Stable.log_prob() rather than a separate StableWithLogProb class (but thank you for drafting a non-invasive solution!). Do you see any blockers to simply merging Stable <-> StableWithLogProb in this PR? I think the only change will be the need to update your tutorial's summary:
## Summary
-- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is undefined.
+- [Stable.log_prob()](http://docs.pyro.ai/en/stable/distributions.html#stable) is very expensive.
and simply omit the reparam stuff from your new section. The single Stable solution is nice in that users will at least be able to use default SVI and HMC, and the older StableReparam machinery can become an approximate cost-saving tool.

EDIT: I guess we'd need to revise some pytest.raises checks in the tests, which might be easiest by adding an internal distribution pyro.distributions.testing.fakes.StableWithoutLogProb.
BTW I am a big fan of the Lévy Stable distribution, and am delighted to see Pyro improving its support for heavy-tailed inference.
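For concreteness, a sketch of what such an internal testing fake might look like (hypothetical code following the suggestion above, not code from this PR):

# A possible pyro.distributions.testing.fakes.StableWithoutLogProb.
from pyro.distributions import Stable

class StableWithoutLogProb(Stable):
    """Stable distribution whose log_prob stays undefined, so existing
    pytest.raises(NotImplementedError) checks still have a target."""

    def log_prob(self, value):
        raise NotImplementedError("log_prob is intentionally undefined")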
Happy to merge as-is, but I think this .log_prob() deserves to be in the main Stable distribution. LMK if you want me to merge now, or if you want to combine Stable <-> StableWithLogProb.
Thx @fritzo for the review! I also think heavy-tailed inference is much needed and I really appreciate all the work done on this so far. It might be better to combine Stable and StableWithLogProb.
One more option that comes to mind is to keep both …
This fixes #3280 by adding pyro.distributions.StableWithLogProb, which is based on pyro.distributions.Stable with an additional log_prob method (I opted for not modifying the pyro.distributions.Stable distribution at this stage).

Code is based on combining #3280 (comment) by @mawright with the existing Stable distribution Pyro code base, with the following modifications:

- Handling of special cases: an alpha value of one, and values at and near zero.
- Removed the dependency on the torchquad package.
- Numerical integration is performed directly in PyTorch (torchquad does this, but overall speed is 25% faster than the reference implementation based on torchquad).
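For illustration, a minimal usage sketch of the new distribution (parameter values are made up, and I assume StableWithLogProb keeps the (stability, skew, scale, loc) constructor of Stable, as the description above suggests):

import pyro.distributions as dist

# Illustrative parameters; constructor assumed to match Stable's signature.
d = dist.StableWithLogProb(stability=1.7, skew=0.5, scale=1.0, loc=0.0)
x = d.sample((1000,))  # sampling works as for any Pyro distribution
logp = d.log_prob(x)   # defined via numerical integration, unlike Stable.log_prob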
Per-iteration duration is about 5 times slower than with reparameterization, but overall convergence is much faster and includes cases which do not converge with reparameterization (like estimation of the skew beta).

The log-probability calculation is based on integration over a uniformly distributed random variable $u$ such that $P(x) = \int du\, P(x|u) P(u)$. The integral can be converted to a reparameterization where we first sample $u$ with probability density $P(u)$ (or $g(u)$ when approximating the posterior distribution by a guide), and second sample or observe $x$ with the distribution $P(x|u)$. Initial tests indicate this reparameterization works but is still slower than estimating the log-probability by integration.
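For reference, the standard latent-variable identity behind this (my gloss, not quoted from the PR): with $u \sim g(u)$ and importance weighting,

$$\log P(x) = \log \mathbb{E}_{u \sim g}\left[\frac{P(x \mid u)\, P(u)}{g(u)}\right] \geq \mathbb{E}_{u \sim g}\left[\log P(x \mid u) + \log P(u) - \log g(u)\right],$$

with equality when $g(u) = P(u \mid x)$. Numerical integration evaluates the left-hand side directly, while the model-plus-guide reparameterization optimizes the right-hand side.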
A usage example with real-life data has been added to the last section of the Stable distribution tutorial.