Trouble with basic parameter estimation with the Stable distribution #3280
See the possibly related #3214. I'll also take a look at this.
I tried using pytorch < 2.0 but still got the same issue. I'm double-checking the math of StableReparam. So far, the _safe_shift/_unsafe_shift code path is unrelated because the stability is 1.5.
One possibility is that the reparametrizer is correct, but the mean-field approximation is poor. Here are some ideas we might try to improve the variational approximation:
One way we could validate this hypothesis is to see if HMC recovers the correct parameters. If so, that would imply the reparametrizers are correct, so the variational approximation is at fault.
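That validation strategy can be illustrated in miniature with plain numpy (my own sketch, not code from this thread): on a conjugate toy model, a random-walk Metropolis chain over the log-scale recovers the generating parameter, which is the kind of agreement one would hope to see from HMC on the Stable model before blaming the variational approximation.

```python
import numpy as np

# Toy recovery check (an illustrative sketch, not the Stable model):
# data ~ Normal(0, 1.5); sample p(log_scale | data) under a flat prior
# with random-walk Metropolis and confirm the posterior recovers 1.5.
rng = np.random.default_rng(0)
toy_data = rng.normal(0.0, 1.5, size=2000)

def log_post(log_scale):
    s = np.exp(log_scale)
    return -toy_data.size * np.log(s) - np.sum(toy_data**2) / (2.0 * s**2)

cur = 0.0
cur_lp = log_post(cur)
samples = []
for _ in range(5000):
    prop = cur + 0.05 * rng.normal()
    prop_lp = log_post(prop)
    if np.log(rng.uniform()) < prop_lp - cur_lp:  # Metropolis accept step
        cur, cur_lp = prop, prop_lp
    samples.append(cur)

scale_est = np.exp(np.mean(samples[1000:]))  # ~1.5 if recovery works
```

If MCMC on the real model similarly recovers the parameters, the remaining SVI bias can be attributed to the guide rather than to the model or reparametrizer.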
Here's a notebook examining the posteriors over the latents introduced by the reparameterizer. This suggests the SVI estimates of the parameters may be off.
I used a custom guide which includes auxiliary variables (so prior and guide cancel out) and fitted the skewness to maximize the (MC-estimated) likelihood - but no luck. Doing some random things like detaching abs(skewness) in the reparam implementation seems to help, so it could be that some grad is wrong (due to clipping or something). I'll fit a jax version tomorrow to see how things look.
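One way to probe for a bad pathwise gradient is to differentiate the sampler directly with fixed base noise. Below is a numpy sketch (my own, not Pyro code) of the standard Chambers-Mallows-Stuck transform; sign conventions for the skew term differ between references, so treat the exact formula as an assumption. A central finite difference in the skewness at fixed (u, w) approximates the reparameterized gradient that autograd should produce, so kinks or clipping effects would show up here.

```python
import numpy as np

# Chambers-Mallows-Stuck transform (alpha != 1): maps u ~ Uniform(-pi/2, pi/2)
# and w ~ Exponential(1) to a Stable(alpha, beta) draw.  Sign conventions for
# zeta vary between references; this is an assumed, illustrative version.
def cms_stable(u, w, alpha, beta):
    zeta = beta * np.tan(np.pi * alpha / 2.0)
    b = np.arctan(zeta) / alpha
    s = (1.0 + zeta**2) ** (1.0 / (2.0 * alpha))
    return (
        s
        * np.sin(alpha * (u + b))
        / np.cos(u) ** (1.0 / alpha)
        * (np.cos(u - alpha * (u + b)) / w) ** ((1.0 - alpha) / alpha)
    )

rng = np.random.default_rng(0)
u = rng.uniform(-np.pi / 2, np.pi / 2, size=10_000)
w = rng.exponential(size=10_000)

# Pathwise derivative wrt the skewness at fixed base noise, via central differences.
eps = 1e-4
g = (cms_stable(u, w, 1.5, 0.5 + eps) - cms_stable(u, w, 1.5, 0.5 - eps)) / (2 * eps)
```

A quick sanity check on the transform itself: at beta = 0 it is odd in u, so `cms_stable(u, w, a, 0.0)` should equal `-cms_stable(-u, w, a, 0.0)`.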
Turns out that there's something wrong with my understanding of auxiliary methods. Let's consider a simple example: A + X = Normal(0, 1) + Normal(0, scale) ~ Normal(0, sqrt(2)), where A is an auxiliary variable. It's clear that the expected scale is 1 (the variances add: 1 + scale^2 = 2). In the following, I used 4 approaches:
```python
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoNormal
import optax


def model(data=None, scale_init=1.0):
    scale = numpyro.param("scale", scale_init, constraint=dist.constraints.positive)
    # jax.debug.print("scale={scale}", scale=scale)
    with numpyro.plate("N", data.shape[0] if data is not None else 10000):
        auxiliary = numpyro.sample("auxiliary", dist.Normal(0, 1))
        return numpyro.sample("obs", dist.Normal(auxiliary, scale), obs=data)


data = numpyro.handlers.seed(model, rng_seed=0)(scale_init=1.0)
print("Data std:", jnp.std(data))

svi = SVI(model, AutoDelta(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using AutoDelta", svi_results.params["scale"])

svi = SVI(model, AutoNormal(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using AutoNormal", svi_results.params["scale"])


def guide(data=None, scale_init=1.0):
    with numpyro.plate("N", data.shape[0]):
        loc = numpyro.param("loc", jnp.zeros_like(data))
        numpyro.sample("auxiliary", dist.Normal(loc, 1))


svi = SVI(model, guide, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using CustomGuide", svi_results.params["scale"])
```

which gives us
So I think to make the auxiliary method work, we need to know the geometry of the posterior of the auxiliary variables. In the stable case, the posterior of the auxiliary variables presumably has a more complicated geometry.
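In the Gaussian toy above that geometry is actually available in closed form (a sketch I am adding; `auxiliary_posterior` is a hypothetical helper, not thread code): Gaussian conjugacy gives the exact posterior of each auxiliary variable, and at the true scale its standard deviation is 1/sqrt(2) ≈ 0.707, not the 1 hard-coded in the custom guide, so that guide can never match the posterior.

```python
import numpy as np

# Hypothetical helper (not from the thread): exact posterior p(A | obs = y)
# for A ~ Normal(0, 1) and obs = A + Normal(0, scale), by Gaussian conjugacy.
def auxiliary_posterior(y, scale):
    var = 1.0 / (1.0 + 1.0 / scale**2)  # = scale^2 / (1 + scale^2)
    mean = var * y / scale**2           # = y / (1 + scale^2)
    return mean, np.sqrt(var)

mean, sd = auxiliary_posterior(2.0, 1.0)  # -> mean 1.0, sd 0.7071...
```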
Nice, @fritzo! I understand your comment better now after playing with the above toy example. So we need to use a more sophisticated guide here?
@fehiepsi correct, we need a more sophisticated guide. I just played around with Pyro's conditional normalizing flows, but I haven't gotten anything working yet 😞 I can't even seem to get an amortized diagonal normal guide working (the bottom of this notebook). |
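The size of the bias from a too-rigid guide can be worked out exactly for the Gaussian toy model (my own derivation, not library code): with the ELBO-optimal per-point location, the expected ELBO has a closed form, and maximizing it over sigma with the guide scale fixed at 1 lands at a value well above the true sigma = 1, while learning the guide scale removes the bias.

```python
import numpy as np

# Closed-form expected ELBO (per data point, up to constants) for the toy model
#   A ~ Normal(0, 1),  obs = A + Normal(0, sigma),  data generated with sigma = 1,
# under a guide q(A) = Normal(loc, s) with the ELBO-optimal loc = y / (1 + sigma^2),
# averaging over y ~ Normal(0, sqrt(2)).  Derived by hand for this toy model.
def expected_elbo(sigma, s):
    v = sigma**2
    e_loc2 = 2.0 / (1.0 + v) ** 2           # E[loc^2]
    e_resid2 = 2.0 * v**2 / (1.0 + v) ** 2  # E[(y - loc)^2]
    return (
        -0.5 * (e_loc2 + s**2)              # E_q[log p(A)]
        - (e_resid2 + s**2) / (2.0 * v)     # E_q[log p(y | A)]
        - np.log(sigma)
        + np.log(s)                         # guide entropy
    )

sigmas = np.linspace(0.5, 2.5, 2001)
s_opt = sigmas / np.sqrt(1.0 + sigmas**2)   # ELBO-optimal guide scale per sigma

best_learned = sigmas[np.argmax(expected_elbo(sigmas, s_opt))]  # ~1.0: unbiased
best_fixed = sigmas[np.argmax(expected_elbo(sigmas, 1.0))]      # ~1.36: biased
```

With the guide scale pinned at 1 (as in the CustomGuide above), the ELBO-optimal sigma shifts to about 1.36, which is exactly the kind of systematic bias under discussion; a guide flexible enough to match the posterior geometry makes it disappear.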
I missed that you guys moved to GitHub. Repeating what I wrote on the Pyro board, I implemented a [...]. After some testing it seems to get a little unstable near [...].
[Moved the discussion in this forum thread to here.]
Doing MLE on the skewness parameter of the Stable distribution does not recover the original parameter. The skewness tends to stay around 0 regardless of initial values.
gives us something like