Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update gmm and dmm examples #26

Merged
merged 8 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,32 @@

Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box.

Coix is a lightweight framework which includes the following main components:

- **coix.api:** Implementation of the program combinators.
- **coix.core:** Basic program transformations which are used to modify behavior of a stochastic program.
- **coix.loss:** Common objectives for variational inference.
- **coix.algo:** Example inference algorithms.

Currently, we support [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends. But other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility.

*This is not an officially supported Google product.*

## Installation

To install Coix, you can use pip:

```
pip install coix
```

or you can clone the repository:

```
git clone https://github.com/jax-ml/coix.git
cd coix
pip install -e .[dev,doc]
```

Many examples would run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instruction for how to install JAX with GPU or TPU support.

5 changes: 5 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ def wrapped(*args, **kwargs):


def empirical(out, trace, metrics):
"""Creates an empirical program given a trace."""
return get_backend()["empirical"](out, trace, metrics)


def suffix(p):
"""Adds suffix `_PREV_` to variable names of `p`."""
fn = get_backend()["suffix"]
if fn is not None:
return fn(p)
Expand All @@ -149,6 +151,7 @@ def suffix(p):


def detach(p):
"""Makes random variables in `p` become non-reparameterized."""
fn = get_backend()["detach"]
if fn is not None:
return fn(p)
Expand All @@ -157,6 +160,7 @@ def detach(p):


def stick_the_landing(p):
"""Stops gradient of distributions' parameters before computing log prob."""
fn = get_backend()["stick_the_landing"]
if fn is not None:
return fn(p)
Expand All @@ -165,6 +169,7 @@ def stick_the_landing(p):


def prng_key():
"""Generates a random JAX PRNGKey."""
fn = get_backend()["prng_key"]
if fn is not None:
return fn()
Expand Down
9 changes: 7 additions & 2 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def wrapped(*args, **kwargs):
if site["type"] == "sample":
value = site["value"]
log_prob = site["fn"].log_prob(value)
trace[name] = {"value": value, "log_prob": log_prob}
event_dim_holder = jnp.empty([1] * site["fn"].event_dim)
trace[name] = {
"value": value,
"log_prob": log_prob,
"_event_dim_holder": event_dim_holder,
}
if site.get("is_observed", False):
trace[name]["is_observed"] = True
metrics = {
Expand Down Expand Up @@ -83,7 +88,7 @@ def wrapped(*args, **kwargs):
del args, kwargs
for name, site in trace.items():
value, lp = site["value"], site["log_prob"]
event_dim = jnp.ndim(value) - jnp.ndim(lp)
event_dim = jnp.ndim(site["_event_dim_holder"])
obs = value if "is_observed" in site else None
numpyro.sample(name, dist.Delta(value, lp, event_dim=event_dim), obs=obs)
for name, value in metrics.items():
Expand Down
Binary file added docs/_static/anneal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/anneal_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bmnist.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/dmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/dmm_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gmm_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 6 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,13 @@
):
toctree_path = "notebooks/" if src_file.endswith("ipynb") else "examples/"
filename = os.path.splitext(src_file.split("/")[-1])[0]
png_path = "_static/" + filename + ".png"
img_path = "_static/" + filename + ".png"
# use Coix logo if not exist png file
if not os.path.exists(png_path):
png_path = "_static/coix_logo.png"
nbsphinx_thumbnails[toctree_path + filename] = png_path
if not os.path.exists(img_path):
img_path = "_static/" + filename + ".gif"
if not os.path.exists(img_path):
img_path = "_static/coix_logo.png"
nbsphinx_thumbnails[toctree_path + filename] = img_path


# -- Options for HTML output -------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Coix Documentation
examples/gmm
examples/dmm
examples/bmnist
examples/anneal_oryx
examples/gmm_oryx
examples/dmm_oryx
examples/anneal_oryx

Indices and tables
==================
Expand Down
5 changes: 4 additions & 1 deletion examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.

.. image:: ../_static/anneal.png
:align: center

"""

import argparse
Expand Down Expand Up @@ -199,7 +202,7 @@ def eval_program(seed):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
7 changes: 5 additions & 2 deletions examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.

.. image:: ../_static/anneal_oryx.png
:align: center

"""

import argparse
Expand Down Expand Up @@ -119,7 +122,7 @@ def __call__(self, x):
def anneal_target(network, key, k=0):
key_out, key = random.split(key)
x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key)
coix.factor(network.anneal_density(x, index=k), name="anneal_density")
coryx.factor(network.anneal_density(x, index=k), name="anneal_density")
return key_out, {"x": x}


Expand Down Expand Up @@ -192,7 +195,7 @@ def main(args):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
Loading
Loading