Trouble with basic parameter estimation with the Stable distribution #3280

Closed
fehiepsi opened this issue Oct 6, 2023 · 9 comments · Fixed by #3369

Comments

@fehiepsi
Member

fehiepsi commented Oct 6, 2023

[Moved the discussion in this forum thread to here.]

Doing MLE on the skewness parameter of the Stable distribution does not recover the true parameter. The estimated skewness stays around 0 regardless of the initial value.

# adapted from https://github.com/fritzo/notebooks/blob/master/stable_mle.ipynb
import math
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.reparam import MinimalReparam
from pyro.infer.autoguide import AutoNormal, AutoGaussian

torch.set_default_dtype(torch.float64)
pyro.set_rng_seed(20230928)

# Define true parameters and number of datapoints
alpha = 1.5
beta = 0.8
c = 1.0
mu = 0.0
n = 10000

# sample data
data = dist.Stable(alpha, beta, c, mu).sample((n,))

@MinimalReparam()
def model(data):    
    alpha = 1.5  # pyro.param("alpha", torch.tensor(1.99), constraint=constraints.interval(0, 2))
    beta = pyro.param("beta", torch.tensor(0.5), constraint=constraints.interval(-1, 1))
    c = 1.0  # pyro.param("c", torch.tensor(1.0), constraint=constraints.positive)
    mu = 0.0  # pyro.param("mu", torch.tensor(0.0), constraint=constraints.real)
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Stable(alpha, beta, c, mu), obs=data)

def train(model, guide, num_steps=1001, lr=0.1):
    pyro.clear_param_store()
    pyro.set_rng_seed(20230928)

    # set up ELBO, and optimizer
    elbo = Trace_ELBO()
    elbo.loss(model, guide, data=data)
    optim = pyro.optim.Adam({"lr": lr})
    svi = SVI(model, guide, optim, loss=elbo)

    # optimize
    losses = []
    for i in range(num_steps):
        loss = svi.step(data) / data.numel()
        losses.append(loss)
        if i % 100 == 0:
            print(f"step {i} loss = {loss:0.6g}")

    print(f"Parameter estimates (n = {n}):")
    print(f"beta: Estimate = {pyro.param('beta').item()}, true = {beta}")
    return losses

guide = AutoNormal(model)
train(model, guide);

gives us something like

step 0 loss = 57.484
step 100 loss = 2.82283
step 200 loss = 2.64291
step 300 loss = 2.57734
step 400 loss = 2.57825
step 500 loss = 2.55413
step 600 loss = 2.58021
step 700 loss = 2.55925
step 800 loss = 2.56627
step 900 loss = 2.55161
step 1000 loss = 2.54688
Parameter estimates (n = 10000):
beta: Estimate = -0.0007896144629401247, true = 0.8
@fehiepsi fehiepsi added the bug label Oct 6, 2023
@fritzo
Member

fritzo commented Oct 6, 2023

See the possibly related #3214. I'll also take a look at this.

@fehiepsi
Member Author

fehiepsi commented Oct 6, 2023

I tried using PyTorch < 2.0 but still got the same issue. I'm double-checking the math of StableReparam. So far, the _safe_shift/_unsafe_shift logic seems unrelated because the stability is 1.5, far from 1.
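
As a quick sanity check of that last point, assume the shift in question is the usual location offset β·c·tan(πα/2) between the S0 and S1 parametrizations of a Stable law (an assumption about what `_safe_shift`/`_unsafe_shift` compute, not verified against Pyro's source). The hypothetical helper below shows that the offset is a modest number at α = 1.5 and only blows up as α approaches 1:

```python
import math


def s0_s1_offset(alpha: float, beta: float, scale: float = 1.0) -> float:
    """Hypothetical helper: the location offset beta * scale * tan(pi * alpha / 2)
    relating the S0 and S1 Stable parametrizations (defined only for alpha != 1)."""
    return beta * scale * math.tan(math.pi * alpha / 2)


# At the parameters used in this issue, tan(3*pi/4) = -1, so the offset is about -0.8.
print(s0_s1_offset(1.5, 0.8))

# As alpha -> 1 the tangent diverges, which is where a "safe" variant would matter.
for alpha in (1.1, 1.01, 1.001):
    print(alpha, s0_s1_offset(alpha, 0.8))
```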

@fritzo
Member

fritzo commented Oct 6, 2023

One possibility is that the reparametrizer is correct but the mean-field AutoNormal variational posterior is really bad. We can't rely on the Bernstein-von Mises theorem here: StableReparam introduces four new latent variables per data point, so the number of latents grows with the data and their posteriors never concentrate. Hence Gaussian guides don't get better with more data.
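
For background on where per-datum latents come from: a Stable draw can be written as a deterministic transform of one uniform and one exponential random variable (the Chambers-Mallows-Stuck method). The stdlib sketch below implements that standard transform in the S1 parametrization for α ≠ 1; it is an illustration, not Pyro's `StableReparam` code. It assumes StableReparam uses two such uniform/exponential pairs per datum (the zu, ze, tu, te sites mentioned further down), which would account for the four extra latents. The function names are hypothetical.

```python
import math
import random


def cms_stable(alpha: float, beta: float, v: float, w: float) -> float:
    """Chambers-Mallows-Stuck transform (hypothetical helper name): map a uniform
    v in (-pi/2, pi/2) and an exponential w > 0 to a standard Stable(alpha, beta)
    draw in the S1 parametrization (scale 1, location 0). Assumes alpha != 1."""
    t = beta * math.tan(math.pi * alpha / 2)
    b = math.atan(t) / alpha
    s = (1.0 + t * t) ** (1.0 / (2.0 * alpha))
    return (
        s
        * math.sin(alpha * (v + b))
        / math.cos(v) ** (1.0 / alpha)
        * (math.cos(v - alpha * (v + b)) / w) ** ((1.0 - alpha) / alpha)
    )


def sample_stable(alpha: float, beta: float, n: int, rng: random.Random) -> list:
    # Each draw consumes one uniform latent and one exponential latent.
    out = []
    for _ in range(n):
        v = rng.uniform(-math.pi / 2, math.pi / 2)
        w = rng.expovariate(1.0)
        out.append(cms_stable(alpha, beta, v, w))
    return out


rng = random.Random(0)
draws = sample_stable(1.5, 0.8, 5, rng)  # same alpha and beta as the issue
print(draws)
```

For α = 2 and β = 0 the transform reduces to 2·sin(V)·sqrt(W), a Normal draw with variance 2, which makes a convenient sanity check.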

Here are some ideas we might try to improve the variational approximation:

  • Try introducing a correlated guide (as in my notebook). This doesn't seem to help.
  • Try using Beta posteriors for the latent uniform random variables and/or Gamma posteriors for the latent exponential variables. These might respect the constraint boundaries better than AutoNormal.
  • Try amortizing with a richer model. These reparametrizers are well suited to an amortized map nn: datum -> params, where we simply fit a curve in parameter space ranging from datum=-inf to datum=inf. For StableReparam with an AutoNormal guide this would be a function R -> R^8 (or R -> R^4 x (0,inf)^4) that could be fit via independent splines. This trick won't improve the variational fit itself, but it will speed up the fitting computation, so we could use it together with richer reparametrizations.
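
The amortization idea in the last bullet can be sketched as follows, under loud assumptions: the class and function names, the knot placement, and the use of piecewise-linear (rather than smoother) splines are all hypothetical choices, and the 8 outputs stand for 4 locs plus 4 unconstrained log-scales of an AutoNormal-style guide. The point is only that the number of guide parameters becomes len(knots) × 8, independent of the number of data points.

```python
import bisect
import math
from typing import List, Sequence


class SplineAmortizer:
    """Hypothetical amortized guide map R -> R^num_outputs.

    Each output is an independent piecewise-linear spline over shared knots, so the
    parameter count is len(knots) * num_outputs no matter how many data points exist.
    """

    def __init__(self, knots: Sequence[float], num_outputs: int = 8):
        self.knots = sorted(knots)
        # values[k][j] is output j at knot k; in practice these would be trained.
        self.values = [[0.0] * num_outputs for _ in self.knots]

    def __call__(self, x: float) -> List[float]:
        ks = self.knots
        if x <= ks[0]:
            return list(self.values[0])  # hold constant below the first knot
        if x >= ks[-1]:
            return list(self.values[-1])  # hold constant above the last knot
        i = bisect.bisect_right(ks, x)  # now ks[i - 1] <= x < ks[i]
        t = (x - ks[i - 1]) / (ks[i] - ks[i - 1])
        lo, hi = self.values[i - 1], self.values[i]
        return [(1.0 - t) * a + t * b for a, b in zip(lo, hi)]


def split_guide_params(raw: List[float]):
    """Interpret 8 raw outputs as 4 locs and 4 positive scales (R^4 x (0, inf)^4)."""
    locs = raw[:4]
    scales = [math.exp(r) for r in raw[4:]]
    return locs, scales


amortizer = SplineAmortizer([-10.0, -1.0, 0.0, 1.0, 10.0])
amortizer.values[3] = [1.0] * 8  # pretend these were learned at the knot x = 1.0
locs, scales = split_guide_params(amortizer(0.5))
print(locs, scales)
```

Linear interpolation with flat extrapolation is simply the easiest curve to write down; any smooth per-output spline fits the same R -> R^8 shape.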

One way we could validate this hypothesis is to check whether HMC recovers the correct parameters. If so, the reparametrizers are correct and the variational approximation is at fault.

@fritzo
Member

fritzo commented Oct 6, 2023

Here's a notebook examining the posteriors over the latents introduced by StableReparam. Indeed the posteriors from HMC look quite non-Gaussian for observations in the tail, for example:

[Figure: HMC posterior plot (image not preserved)]

This suggests the SVI estimates of parameters may be off.

@fehiepsi
Member Author

fehiepsi commented Oct 7, 2023

I used a custom guide that includes the auxiliary variables (so the prior and guide terms cancel) and fitted the skewness to maximize the (Monte Carlo estimated) likelihood, but no luck. Ad hoc changes such as detaching abs(skewness) in the reparam implementation seem to help, so some gradient may be wrong (perhaps due to clipping). I'll try a JAX version tomorrow to see how things look.
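
The difference between maximizing a Monte Carlo estimate of the likelihood and maximizing an ELBO whose guide is just the prior can be seen in the Gaussian toy model from the next comment (A ~ Normal(0, 1), y | A ~ Normal(A, s)). The stdlib sketch below compares, for one datum, the exact log marginal, a log-mean-exp Monte Carlo estimate over prior samples, and the average log-likelihood over prior samples (the ELBO when guide = prior). Function names are hypothetical.

```python
import math
import random


def log_normal_pdf(x: float, loc: float, scale: float) -> float:
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exact_log_marginal(y: float, s: float) -> float:
    # Integrating out A ~ Normal(0, 1) gives y ~ Normal(0, sqrt(1 + s^2)).
    return log_normal_pdf(y, 0.0, math.sqrt(1.0 + s * s))


def mc_log_marginal(y: float, s: float, num_samples: int, rng: random.Random) -> float:
    # Log of the average likelihood over prior draws (log-mean-exp, stabilized).
    logs = [log_normal_pdf(y, rng.gauss(0.0, 1.0), s) for _ in range(num_samples)]
    m = max(logs)
    return m + math.log(sum(math.exp(v - m) for v in logs) / num_samples)


def prior_guide_elbo(y: float, s: float, num_samples: int, rng: random.Random) -> float:
    # Average log-likelihood over prior draws: the ELBO when the guide is the prior,
    # since log p(a) - log q(a) cancels. By Jensen's inequality it is <= log p(y).
    total = sum(log_normal_pdf(y, rng.gauss(0.0, 1.0), s) for _ in range(num_samples))
    return total / num_samples


rng = random.Random(0)
y, s = 0.7, 1.0
print(exact_log_marginal(y, s), mc_log_marginal(y, s, 20000, rng), prior_guide_elbo(y, s, 20000, rng))
```

The log-mean-exp estimate approaches the exact value as the sample count grows (with a small downward bias for finite samples), while the prior-guide ELBO stays strictly below it.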

@fehiepsi
Member Author

fehiepsi commented Oct 7, 2023

It turns out that something was wrong with my understanding of auxiliary-variable methods. Consider a simple example: the data are A + X, where A ~ Normal(0, 1) is an auxiliary variable and X ~ Normal(0, scale). With scale = 1 the data follow Normal(0, sqrt(2)), so the scale we expect to recover is 1. I tried 4 approaches:

  • without a guide: the auxiliary variable is drawn from its prior and acts as pure noise, so the residual data - A ~ Normal(0, sqrt(2)) + Normal(0, 1) = Normal(0, sqrt(3)), and SVI gives scale ≈ sqrt(3).
  • with an AutoDelta guide: scale ≈ 0.02. I think the MAP point for each auxiliary variable ends up near its data point, so the scale shrinks toward 0.
  • with an AutoNormal guide: scale ≈ 1, which is what we want.
  • with a custom guide with a fixed scale: scale ≈ 1.38
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoNormal
import optax

def model(data=None, scale_init=1.0):
    scale = numpyro.param("scale", scale_init, constraint=dist.constraints.positive)
    # jax.debug.print("scale={scale}", scale=scale)
    with numpyro.plate("N", data.shape[0] if data is not None else 10000):
        auxiliary = numpyro.sample("auxiliary", dist.Normal(0, 1))
        return numpyro.sample("obs", dist.Normal(auxiliary, scale), obs=data)

data = numpyro.handlers.seed(model, rng_seed=0)(scale_init=1.0)
print("Data std:", jnp.std(data))
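# No guide: the auxiliary variable is sampled from its prior
svi = SVI(model, lambda _: None, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using no guide', svi_results.params['scale'])

# AutoDelta guide: a point (MAP) estimate for each auxiliary variable
svi = SVI(model, AutoDelta(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using AutoDelta', svi_results.params['scale'])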

svi = SVI(model, AutoNormal(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using AutoNormal', svi_results.params['scale'])

# Custom guide: a learnable loc per datum, with the guide scale fixed to 1
def guide(data=None, scale_init=1.0):
    with numpyro.plate("N", data.shape[0]):
        loc = numpyro.param("loc", jnp.zeros_like(data))
        numpyro.sample("auxiliary", dist.Normal(loc, 1))

svi = SVI(model, guide, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using CustomGuide', svi_results.params['scale'])

gives us

Data std: 1.4132578
scale using no guide 1.7334453
scale using AutoDelta 0.02048848
scale using AutoNormal 1.0833057
scale using CustomGuide 1.3865443
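
Because every distribution in the toy model is Gaussian, these numbers can be predicted in closed form. The stdlib sketch below computes idealized population optima for scale, assuming Var(y) = 2 exactly, a fully converged optimizer, and no Monte Carlo noise, so it need not match the SVI printout digit for digit (for the fixed-scale custom guide it gives about 1.36, against the 1.39 printed above, and for AutoDelta it predicts a collapse toward 0). Function names are hypothetical; the derivations are in the comments.

```python
import math

# Toy model: A ~ Normal(0, 1), y | A ~ Normal(A, s); data generated with s = 1.
VAR_Y = 2.0  # population variance of y (y has mean 0)


def scale_no_guide(var_y: float) -> float:
    # A is drawn from its prior, so the objective is E[log N(y | A, s)];
    # its maximizer is s^2 = E[(y - A)^2] = var_y + 1.
    return math.sqrt(var_y + 1.0)


def scale_autonormal(var_y: float) -> float:
    # The exact posterior A | y ~ Normal(y / (1 + s^2), sqrt(s^2 / (1 + s^2))) is
    # Gaussian, so the ELBO is tight and its optimum is the MLE of the marginal
    # y ~ Normal(0, sqrt(1 + s^2)), i.e. s^2 = var_y - 1.
    return math.sqrt(var_y - 1.0)


def scale_unit_scale_guide(var_y: float, iters: int = 200) -> float:
    # Guide q(A_i) = Normal(m_i, 1). For fixed s the best m_i is y_i / (1 + s^2);
    # for fixed m the best s^2 is E[(y - m)^2] + 1 (the +1 is the guide variance).
    # Alternate the two updates until they reach a fixed point in v = s^2.
    v = 1.0
    for _ in range(iters):
        v = var_y * v * v / (1.0 + v) ** 2 + 1.0
    return math.sqrt(v)


def autodelta_objective(var_y: float, s: float) -> float:
    # Point-mass guide: average of log N(y | m, s) + log N(m | 0, 1) with the best
    # m = y / (1 + s^2). No entropy term penalizes s -> 0, and the -log(s) term
    # makes the objective grow without bound, so the fitted scale collapses.
    v = s * s
    resid = var_y * v * v / (1.0 + v) ** 2  # E[(y - m)^2]
    prior = var_y / (1.0 + v) ** 2  # E[m^2]
    return -math.log(s) - resid / (2.0 * v) - prior / 2.0 - math.log(2.0 * math.pi)


print("no guide   ", scale_no_guide(VAR_Y))
print("AutoNormal ", scale_autonormal(VAR_Y))
print("unit guide ", scale_unit_scale_guide(VAR_Y))
print("AutoDelta objective at s = 1, 0.1, 0.01:",
      [autodelta_objective(VAR_Y, s) for s in (1.0, 0.1, 0.01)])
```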

So I think that for auxiliary-variable methods to work, we need to know the geometry of the posterior over the auxiliary variables. In the Stable case, the posteriors of the auxiliary variables zu, ze, tu, te are likely non-Gaussian. I'm not sure what a good approach is here.

@fehiepsi
Copy link
Member Author

fehiepsi commented Oct 7, 2023

Indeed the posteriors from HMC look quite non-Gaussian for observations in the tail

Nice, @fritzo! I understand your comment better now after playing with the above toy example. So we need to use a more sophisticated guide here?

@fehiepsi fehiepsi added discussion and removed bug labels Oct 7, 2023
@fritzo
Member

fritzo commented Oct 7, 2023

@fehiepsi Correct, we need a more sophisticated guide. I just played around with Pyro's conditional normalizing flows, but I haven't gotten anything working yet 😞. I can't even seem to get an amortized diagonal normal guide working (see the bottom of this notebook).

@mawright

mawright commented Oct 11, 2023

I missed that you moved the discussion to GitHub. Repeating what I wrote on the Pyro forum: I implemented a log_prob() for Stable, and I have a similar MLE example in this notebook (https://github.com/mawright/torchstable/blob/main/stable_demo.ipynb). It seems to be working well:

[Figure: MLE results from the notebook (image not preserved)]

After some testing, it seems to become a little unstable near $\alpha=1$ (e.g., for a true value of 1.1 it converges to 1.15 with noiseless data), which I'm looking into fixing.
