Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User contributed tutorial: Bayesian Hierarchical Stacking case study #1161

Merged
merged 13 commits into from
Oct 11, 2021

Conversation

MarcoGorelli
Copy link
Contributor

@MarcoGorelli MarcoGorelli commented Sep 24, 2021

Hi,

With reference to #189 , I figured I'd try submitting an implementation I recently did of the first case study from the Bayesian Hierarchical Stacking paper.

I'm aware that the issue reads

If you are interested, please open up an issue describing your model so that one of the core contributors can help assess feasibility, put into place any features that we are currently lacking, and work with you on the PR

and that I haven't opened an issue first asking if this would be welcome. However, I'd already made the notebook, and so figured the fastest way to move the conversation forward would be by opening a PR directly.
No hard feelings / offense taken if this isn't welcome of course 😄

Full disclosure: I've also submitted a PR with a PyMC3 version of this case study to the pymc-examples repo


This took about 40s to run on my laptop

@MarcoGorelli MarcoGorelli force-pushed the bayesian-hierarchical-stacking branch 3 times, most recently from dcdaafd to a2c3693 Compare September 24, 2021 12:10
@MarcoGorelli MarcoGorelli force-pushed the bayesian-hierarchical-stacking branch from a2c3693 to 3400cc8 Compare September 24, 2021 12:14
@MarcoGorelli MarcoGorelli changed the title User contributed example: Bayesian Hierarchical Stacking case study User contributed tutorial: Bayesian Hierarchical Stacking case study Sep 24, 2021
@martinjankowiak
Copy link
Collaborator

hi @MarcoGorelli thanks for the contribution! i'm sure this will be welcome. that note is mostly about trying to save the PR writer time (e.g. if a similar example already exists and they are unaware of it; if they are unaware of helper utilities that can simplify the code; etc)

@martinjankowiak
Copy link
Collaborator

here are some more detailed comments/suggestions:

  • can you make the "here" a link instead of a link to a link at the end?
  • could you add a bit more detail to the docstring for bs? in particular inputs/outputs?
  • can you be more specific about what you mean by a "stacking dataset"? this might be a bit confusing
  • a bit more explanation of stacking might be helpful; e.g. add comments to code?
  • can you add some comments to cells 20-21-22?
  • can you please make the plot in cell 23 more legible? maybe a violin plot would be better?

@MarcoGorelli MarcoGorelli marked this pull request as draft September 24, 2021 20:36
@MarcoGorelli MarcoGorelli marked this pull request as ready for review September 25, 2021 09:49
@martinjankowiak
Copy link
Collaborator

looks good to me pending any further comments from @fehiepsi

@fehiepsi
Copy link
Member

fehiepsi commented Sep 28, 2021

The tutorial is great. Thanks, @MarcoGorelli! I just have some comments:

  • probably set progress_bar=False to cleanup the widget output
  • section 3.2: use Pred_{i,k} rather than Pred_{i, k} i.e. no space between i and k. Same for W.
  • use the convention jnp rather than jax.numpy
  • vectorize the vstack logic (in jax, it is better to avoid for loop). For example, replace
jax.numpy.vstack(
                [jax.numpy.matmul(X, beta[k]) for k in range(K - 1)],
            ).T

by

beta @ X.T

and some hstack logic

  • as a tutorial, it is better to use .expand(...) rather than using sample shape, for example
beta_con = numpyro.sample(
        "beta_con",
        dist.Normal(0, 1),
        sample_shape=(K - 1, d - d_discrete),
    )

can be written as

beta_con = numpyro.sample(
        "beta_con",
        dist.Normal(0, 1).expand([K - 1, d - d_discrete])
    )

And it is best to use plate instead of explicitly define batch shapes at each site.

  • Not important - you can use a more numerical version of
    w = numpyro.deterministic("w", jax.nn.softmax(f, axis=1))

    # log probability of LOO training scores weighted by stacking weights.
    logp = jax.numpy.log((exp_lpd_point * w).sum(axis=1))

with

log_w = jax.nn.log_softmax(f, axis=1)
w = numpyro.deterministic("w", jnp.exp(log_w))
logp = jax.nn.logsumexp(lpd_point + log_w, axis=1)
  • use 1 line for
    numpyro.deterministic(
        "w_test",
        w_test,
    )
  • use different seeds in each code cell. If a seed is used in for k in ... loop, you might want to use
    jax.random.fold_in(jax.random.PRNGKey(current_seed), k)
  • you can add thumbnail to your tutorial by putting the image here

@MarcoGorelli
Copy link
Contributor Author

Hey @fehiepsi and @martinjankowiak ,

thank you for your awesome reviews, much appreciated, I really learned a lot!

I'm off on holiday tomorrow for 2 weeks, so I'll respond to further feedback when I'm back

@fehiepsi
Copy link
Member

fehiepsi commented Oct 4, 2021

Thank you @MarcoGorelli! Please take your time. I just have two comments left:

  • It is better to add a boolean argument test=False to the model and only set it to true when evaluating:
if test:
    # test set stacking weights (in unconstrained space)
    f_test = jnp.hstack([X_test @ beta.T, jnp.zeros((N_test, 1))])
    # test set stacking weights (constrained to sum to 1)
    w_test = numpyro.deterministic("w_test", jax.nn.softmax(f_test, axis=1))
    numpyro.deterministic("w_test", w_test)
  • Getting different random keys in jax is a bit non-trivial. Typically, many jax devs use the following pattern
# in a code cell
rng_key, subkey = random.split(rng_key)
# then use subkey in the stochastic program of this cell

For example,

rng_key, subkey = random.split(rng_key)
for k in range(K):
    key = jax.random.fold_in(subkey, k)

Anyway, for simplicity, I think you can revise the rng_key pattern in cells 13, 17, 19, 20 using the following way

jax.random.fold_in(cell_number, k)  # where cell_number is 13/19/20

and

jax.random.PRNGKey(17)   # in cell 17

You can replace cell numbers by 0, 1, 2, 3, or any of your preferences.

@fehiepsi
Copy link
Member

Thanks for contributing the excellent tutorial, @MarcoGorelli !! I have wanted to learn stacking for a long time.

@fehiepsi fehiepsi merged commit 4ec9fc7 into pyro-ppl:master Oct 11, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants