diff --git a/docs/source/contrib.zuko.rst b/docs/source/contrib.zuko.rst
new file mode 100644
index 0000000000..c7f2dbe7e1
--- /dev/null
+++ b/docs/source/contrib.zuko.rst
@@ -0,0 +1,5 @@
+Zuko in Pyro
+============
+
+.. automodule:: pyro.contrib.zuko
+ :members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 82b70e684f..a5104fb9bc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -45,6 +45,7 @@ Pyro Documentation
contrib.randomvariable
contrib.timeseries
contrib.tracking
+ contrib.zuko
Indices and tables
diff --git a/pyro/contrib/zuko.py b/pyro/contrib/zuko.py
new file mode 100644
index 0000000000..232b773389
--- /dev/null
+++ b/pyro/contrib/zuko.py
@@ -0,0 +1,81 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+This file contains helpers to use `Zuko <https://zuko.readthedocs.io>`_-based
+normalizing flows within Pyro pipelines.
+
+Accompanying tutorials can be found at `tutorial/svi_flow_guide.ipynb` and
+`tutorial/vae_flow_prior.ipynb`.
+"""
+
+import torch
+from torch import Size, Tensor
+
+import pyro
+
+
+class ZukoToPyro(pyro.distributions.TorchDistribution):
+ r"""Wraps a Zuko distribution as a Pyro distribution.
+
+ If ``dist`` has an ``rsample_and_log_prob`` method, like Zuko's flows, it is used
+ when sampling instead of ``rsample``. The log density it returns is cached under the
+ sampled tensor, so a later ``log_prob`` call on that tensor reuses it.
+
+ :param dist: The distribution instance to wrap.
+ :type dist: torch.distributions.Distribution
+
+ .. code-block:: python
+
+ flow = zuko.flows.MAF(features=5)
+
+ # flow() is a torch.distributions.Distribution
+
+ dist = flow()
+ x = dist.sample((2, 3))
+ log_p = dist.log_prob(x)
+
+ # ZukoToPyro(flow()) is a pyro.distributions.Distribution
+
+ dist = ZukoToPyro(flow())
+ x = dist((2, 3))
+ log_p = dist.log_prob(x)
+
+ with pyro.plate("data", 42):
+ z = pyro.sample("z", dist)
+ """
+
+ def __init__(self, dist: torch.distributions.Distribution):
+ self.dist = dist  # wrapped distribution; the properties below delegate to it
+ self.cache = {}  # sampled tensor -> log density returned with it by rsample_and_log_prob
+
+ @property
+ def has_rsample(self) -> bool:
+ return self.dist.has_rsample
+
+ @property
+ def event_shape(self) -> Size:
+ return self.dist.event_shape
+
+ @property
+ def batch_shape(self) -> Size:
+ return self.dist.batch_shape
+
+ def __call__(self, shape: Size = ()) -> Tensor:  # draws a sample; ``shape`` is passed to the wrapped sampler
+ if hasattr(self.dist, "rsample_and_log_prob"): # sample and score in one call; keep the score for log_prob
+ x, self.cache[x] = self.dist.rsample_and_log_prob(shape)
+ elif self.has_rsample:  # no joint method: use rsample when the wrapped distribution advertises it
+ x = self.dist.rsample(shape)
+ else:
+ x = self.dist.sample(shape)
+
+ return x
+
+ def log_prob(self, x: Tensor) -> Tensor:
+ if x in self.cache:  # NOTE(review): dict lookup on a tensor key, presumably matched by object identity - confirm
+ return self.cache[x]
+ else:
+ return self.dist.log_prob(x)
+
+ def expand(self, *args, **kwargs):
+ return ZukoToPyro(self.dist.expand(*args, **kwargs))  # new wrapper, so it starts with an empty cache
diff --git a/tests/contrib/test_zuko.py b/tests/contrib/test_zuko.py
new file mode 100644
index 0000000000..cee04c177b
--- /dev/null
+++ b/tests/contrib/test_zuko.py
@@ -0,0 +1,65 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import pytest
+import torch
+
+import pyro
+from pyro.contrib.zuko import ZukoToPyro
+from pyro.infer import SVI, Trace_ELBO
+from pyro.optim import Adam
+
+
+@pytest.mark.parametrize("multivariate", [True, False])
+@pytest.mark.parametrize("rsample_and_log_prob", [True, False])
+def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool):
+ # Distribution: a normal with zero mean and unit scale, either 3-dimensional multivariate or scalar
+ if multivariate:
+ normal = torch.distributions.MultivariateNormal
+ mu = torch.zeros(3)
+ sigma = torch.eye(3)
+ else:
+ normal = torch.distributions.Normal
+ mu = torch.zeros(())
+ sigma = torch.ones(())
+
+ dist = normal(mu, sigma)
+
+ if rsample_and_log_prob:
+
+ def dummy(self, shape):
+ x = self.rsample(shape)
+ return x, self.log_prob(x)
+
+ dist.rsample_and_log_prob = dummy.__get__(dist)  # bind to this instance so ZukoToPyro's hasattr check finds it
+
+ # Sample outside of any plate: the draw must have exactly the event shape
+ x1 = pyro.sample("x1", ZukoToPyro(dist))
+
+ assert x1.shape == dist.event_shape
+
+ # Sample within a plate of size 4: the draw must gain a leading dimension of 4
+ with pyro.plate("data", 4):
+ x2 = pyro.sample("x2", ZukoToPyro(dist))
+
+ assert x2.shape == (4, *dist.event_shape)
+
+ # SVI smoke test: one optimization step must run with wrapped distributions in both model and guide
+ def model():
+ pyro.sample("a", ZukoToPyro(dist))
+
+ with pyro.plate("data", 4):
+ pyro.sample("b", ZukoToPyro(dist))
+
+ def guide():
+ mu_ = pyro.param("mu", mu)
+ sigma_ = pyro.param("sigma", sigma)
+
+ pyro.sample("a", ZukoToPyro(normal(mu_, sigma_)))
+
+ with pyro.plate("data", 4):
+ pyro.sample("b", ZukoToPyro(normal(mu_, sigma_)))
+
+ svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
+ svi.step()
diff --git a/tutorial/source/index.rst b/tutorial/source/index.rst
index 5d4c0cc12c..442cb74878 100644
--- a/tutorial/source/index.rst
+++ b/tutorial/source/index.rst
@@ -97,6 +97,7 @@ List of Tutorials
jit
svi_horovod
svi_lightning
+ svi_flow_guide
.. toctree::
:maxdepth: 1
@@ -106,7 +107,8 @@ List of Tutorials
vae
ss-vae
cvae
- normalizing_flows_i
+ normalizing_flows_intro
+ vae_flow_prior
dmm
air
cevae
diff --git a/tutorial/source/normalizing_flows_i.ipynb b/tutorial/source/normalizing_flows_intro.ipynb
similarity index 99%
rename from tutorial/source/normalizing_flows_i.ipynb
rename to tutorial/source/normalizing_flows_intro.ipynb
index 87284ba4b4..3617edc98e 100644
--- a/tutorial/source/normalizing_flows_i.ipynb
+++ b/tutorial/source/normalizing_flows_intro.ipynb
@@ -4,10 +4,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Normalizing Flows - Introduction (Part 1)\n",
+ "# Normalizing Flows - Introduction\n",
+ "\n",
+ "This tutorial introduces Pyro's built-in normalizing flows. It is independent of most of Pyro, but users may want to read about distribution shapes in the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html).\n",
+ "\n",
+ "> The development of Pyro's built-in flows has stopped in favor of external libraries, such as [Zuko](https://github.com/probabilists/zuko), [nflows](https://github.com/bayesiains/nflows), [normflows](https://github.com/VincentStimper/normalizing-flows) or [FlowTorch](https://flowtorch.ai/). Some of these libraries may have interfaces that are not directly compatible with Pyro. See the [SVI with flow guide](svi_flow_guide.ipynb) and [VAE with flow prior](vae_flow_prior.ipynb) tutorials for example usages of [Zuko](https://github.com/probabilists/zuko) within Pyro.\n",
"\n",
- "This tutorial introduces Pyro's normalizing flow library. It is independent of much of Pyro, but users may want to read about distribution shapes in the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html).\n",
- " \n",
"## Introduction\n",
"\n",
"In standard probabilistic modeling practice, we represent our beliefs over unknown continuous quantities with simple parametric distributions like the normal, exponential, and Laplacian distributions. However, using such simple forms, which are commonly symmetric and unimodal (or have a fixed number of modes when we take a mixture of them), restricts the performance and flexibility of our methods. For instance, standard variational inference in the Variational Autoencoder uses independent univariate normal distributions to represent the variational family. The true posterior is neither independent nor normally distributed, which results in suboptimal inference and simplifies the model that is learnt. In other scenarios, we are likewise restricted by not being able to model multimodal distributions and heavy or light tails.\n",
@@ -25,8 +27,7 @@
" \n",
"Normalizing Flows are a family of methods for constructing flexible distributions. Let's first restrict our attention to representing univariate distributions. The basic idea is that a simple source of noise, for example a variable with a standard normal distribution, $X\\sim\\mathcal{N}(0,1)$, is passed through a bijective (i.e. invertible) function, $g(\\cdot)$ to produce a more complex transformed variable $Y=g(X)$.\n",
"\n",
- "For a given random variable, we typically want to perform two operations: sampling and scoring. Sampling $Y$ is trivial. First, we sample $X=x$, then calculate $y=g(x)$. Scoring $Y$, or rather, evaluating the log-density $\\log(p_Y(y))$, is more involved. How does the density of $Y$ relate to the density of $X$? We can use the substitution rule of integral calculus to answer this. Suppose we want to evaluate the expectation of some function of $X$. Then,\n",
- "\n",
+ "For a given random variable, we typically want to perform two operations: sampling and scoring. Sampling $Y$ is trivial. First, we sample $X=x$, then calculate $y=g(x)$. Scoring $Y$, or rather, evaluating the log-density $\\log p_Y(y)$, is more involved. How does the density of $Y$ relate to the density of $X$? We can use the substitution rule of integral calculus to answer this. Suppose we want to evaluate the expectation of some function of $X$. Then,\n",
"\n",
"\\begin{align}\n",
"\\mathbb{E}_{p_X(\\cdot)}\\left[f(X)\\right] &= \\int_{\\text{supp}(X)}f(x)p_X(x)dx\\\\\n",
@@ -34,29 +35,22 @@
"&= \\mathbb{E}_{p_Y(\\cdot)}\\left[f(g^{-1}(Y))\\right],\n",
"\\end{align}\n",
"\n",
- "\n",
"where $\\text{supp}(X)$ denotes the support of $X$, which in this case is $(-\\infty,\\infty)$. Crucially, we used the fact that $g$ is bijective to apply the substitution rule in going from the first to the second line. Equating the last two lines we get,\n",
"\n",
- "\n",
"\\begin{align}\n",
- "\\log(p_Y(y)) &= \\log(p_X(g^{-1}(y)))+\\log\\left(\\left|\\frac{dx}{dy}\\right|\\right)\\\\\n",
- "&= \\log(p_X(g^{-1}(y)))-\\log\\left(\\left|\\frac{dy}{dx}\\right|\\right).\n",
+ "\\log p_Y(y) & = \\log p_X(g^{-1}(y)) + \\log\\left|\\frac{dx}{dy}\\right| \\\\\n",
+ "& = \\log p_X(g^{-1}(y)) - \\log\\left|\\frac{dy}{dx}\\right|.\n",
"\\end{align}\n",
"\n",
- "\n",
"Inituitively, this equation says that the density of $Y$ is equal to the density at the corresponding point in $X$ plus a term that corrects for the warp in volume around an infinitesimally small length around $Y$ caused by the transformation.\n",
"\n",
- "If $g$ is cleverly constructed (and we will see several examples shortly), we can produce distributions that are more complex than standard normal noise and yet have easy sampling and computationally tractable scoring. Moreover, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have $L$ transforms $g_{(0)}, g_{(1)},\\ldots,g_{(L-1)}$, then the log-density of the transformed variable $Y=(g_{(0)}\\circ g_{(1)}\\circ\\cdots\\circ g_{(L-1)})(X)$ is\n",
- "\n",
+ "If $g$ is cleverly constructed (and we will see several examples shortly), we can produce distributions that are more complex than standard normal noise and yet have easy sampling and computationally tractable scoring. Moreover, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have a sequence of $L$ transforms $(g_1, g_2, \\ldots, g_L)$ such that $Y = g(X) = g_L \\circ \\cdots \\circ g_2 \\circ g_1(X)$, then the log-density of $Y$ is\n",
"\n",
"\\begin{align}\n",
- "\\log(p_Y(y)) &= \\log\\left(p_X\\left(\\left(g_{(L-1)}^{-1}\\circ\\cdots\\circ g_{(0)}^{-1}\\right)\\left(y\\right)\\right)\\right)+\\sum^{L-1}_{l=0}\\log\\left(\\left|\\frac{dg^{-1}_{(l)}(y_{(l)})}{dy'}\\right|\\right),\n",
- "%\\left( g^{(l)}(y^{(l)})\n",
- "%\\right).\n",
+ "\\log p_Y(y) = \\log p_X(y_0) + \\sum^{L}_{l=1} \\log \\left| \\frac{dg^{-1}_{l}(y_l)}{dy_{l}} \\right|\n",
"\\end{align}\n",
"\n",
- "\n",
- "where we've defined $y_{(0)}=x$, $y_{(L-1)}=y$ for convenience of notation.\n",
+ "where $y_L = y$ and $y_{l-1} = g^{-1}_l(y_{l})$.\n",
"\n",
"In a latter section, we will see how to generalize this method to multivariate $X$. The field of Normalizing Flows aims to construct such $g$ for multivariate $X$ to transform simple i.i.d. standard normal noise into complex, learnable, high-dimensional distributions. The methods have been applied to such diverse applications as image modeling, text-to-speech, unsupervised language induction, data compression, and modeling molecular structures. As probability distributions are the most fundamental component of probabilistic modeling we will likely see many more exciting state-of-the-art applications in the near future."
]
@@ -71,13 +65,11 @@
"\n",
"Let us begin by showing how to represent and manipulate a simple transformed distribution,\n",
"\n",
- "\n",
"\\begin{align}\n",
"X &\\sim \\mathcal{N}(0,1)\\\\\n",
"Y &= \\text{exp}(X).\n",
"\\end{align}\n",
"\n",
- "\n",
"You may have recognized that this is by definition, $Y\\sim\\text{LogNormal}(0,1)$.\n",
"\n",
"We begin by importing the relevant libraries:"
@@ -122,14 +114,12 @@
"source": [
"The class [ExpTransform](https://pytorch.org/docs/master/distributions.html#torch.distributions.transforms.ExpTransform) derives from [Transform](https://pytorch.org/docs/master/distributions.html#torch.distributions.transforms.Transform) and defines the forward, inverse, and log-absolute-derivative operations for this transform,\n",
"\n",
- "\n",
"\\begin{align}\n",
"g(x) &= \\text{exp(x)}\\\\\n",
"g^{-1}(y) &= \\log(y)\\\\\n",
- "\\log\\left(\\left|\\frac{dg}{dx}\\right|\\right) &= x.\n",
+ "\\log \\left|\\frac{dg}{dx}\\right| &= x.\n",
"\\end{align}\n",
"\n",
- "\n",
"In general, a transform class defines these three operations, from which it is sufficient to perform sampling and scoring.\n",
"\n",
"The class [TransformedDistribution](https://pytorch.org/docs/master/distributions.html#torch.distributions.transformed_distribution.TransformedDistribution) takes a base distribution of simple noise and a list of transforms, and encapsulates the distribution formed by applying these transformations in sequence. We use it as:"
@@ -183,13 +173,11 @@
"source": [
"Our example uses a single transform. However, we can compose transforms to produce more expressive distributions. For instance, if we apply an affine transformation we can produce the general log-normal distribution,\n",
"\n",
- "\n",
"\\begin{align}\n",
"X &\\sim \\mathcal{N}(0,1)\\\\\n",
"Y &= \\text{exp}(\\mu+\\sigma X).\n",
"\\end{align}\n",
"\n",
- "\n",
"or rather, $Y\\sim\\text{LogNormal}(\\mu,\\sigma^2)$. In Pyro this is accomplished, e.g. for $\\mu=3, \\sigma=0.5$, as follows:"
]
},
@@ -282,13 +270,13 @@
"plt.show()\n",
"\n",
"plt.subplot(1, 2, 1)\n",
- "sns.distplot(X[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X[:,0], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2})\n",
"plt.title(r'$p(x_1)$')\n",
"plt.subplot(1, 2, 2)\n",
- "sns.distplot(X[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X[:,1], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2})\n",
@@ -356,7 +344,7 @@
" loss.backward()\n",
" optimizer.step()\n",
" flow_dist.clear_cache()\n",
- " \n",
+ "\n",
" if step % 200 == 0:\n",
" print('step: {}, loss: {}'.format(step, loss.item()))"
]
@@ -407,24 +395,24 @@
"plt.show()\n",
"\n",
"plt.subplot(1, 2, 1)\n",
- "sns.distplot(X[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X[:,0], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,0], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='flow')\n",
"plt.title(r'$p(x_1)$')\n",
"plt.subplot(1, 2, 2)\n",
- "sns.distplot(X[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X[:,1], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,1], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
@@ -452,36 +440,30 @@
"\n",
"Sampling $Y$ is again trivial and involves evaluation of the forward pass of $g$. We can score $Y$ using the multivariate substitution rule of integral calculus,\n",
"\n",
- "\n",
"\\begin{align}\n",
"\\mathbb{E}_{p_X(\\cdot)}\\left[f(X)\\right] &= \\int_{\\text{supp}(X)}f(\\mathbf{x})p_X(\\mathbf{x})d\\mathbf{x}\\\\\n",
- "&= \\int_{\\text{supp}(Y)}f(g^{-1}(\\mathbf{y}))p_X(g^{-1}(\\mathbf{y}))\\det\\left|\\frac{d\\mathbf{x}}{d\\mathbf{y}}\\right|d\\mathbf{y}\\\\\n",
+ "&= \\int_{\\text{supp}(Y)}f(g^{-1}(\\mathbf{y}))p_X(g^{-1}(\\mathbf{y}))\\left|\\det\\frac{d\\mathbf{x}}{d\\mathbf{y}}\\right|d\\mathbf{y}\\\\\n",
"&= \\mathbb{E}_{p_Y(\\cdot)}\\left[f(g^{-1}(Y))\\right],\n",
"\\end{align}\n",
"\n",
- "\n",
- "where $d\\mathbf{x}/d\\mathbf{y}$ denotes the Jacobian matrix of $g^{-1}(\\mathbf{y})$. Equating the last two lines we get,\n",
- "\n",
+ "where $\\det \\frac{d\\mathbf{x}}{d\\mathbf{y}}$ denotes the determinant of the Jacobian matrix of $g^{-1}(\\mathbf{y})$. Equating the last two lines we get,\n",
"\n",
"\\begin{align}\n",
- "\\log(p_Y(y)) &= \\log(p_X(g^{-1}(y)))+\\log\\left(\\det\\left|\\frac{d\\mathbf{x}}{d\\mathbf{y}}\\right|\\right)\\\\\n",
- "&= \\log(p_X(g^{-1}(y)))-\\log\\left(\\det\\left|\\frac{d\\mathbf{y}}{d\\mathbf{x}}\\right|\\right).\n",
+ "\\log p_Y(y) &= \\log p_X(g^{-1}(y)) + \\log\\left|\\det\\frac{d\\mathbf{x}}{d\\mathbf{y}}\\right|\\\\\n",
+ "&= \\log p_X(g^{-1}(y)) - \\log\\left|\\det\\frac{d\\mathbf{y}}{d\\mathbf{x}}\\right|.\n",
"\\end{align}\n",
"\n",
"Inituitively, this equation says that the density of $Y$ is equal to the density at the corresponding point in $X$ plus a term that corrects for the warp in volume around an infinitesimally small volume around $Y$ caused by the transformation. For instance, in $2$-dimensions, the geometric interpretation of the absolute value of the determinant of a Jacobian is that it represents the area of a parallelogram with edges defined by the columns of the Jacobian. In $n$-dimensions, the geometric interpretation of the absolute value of the determinant Jacobian is that is represents the hyper-volume of a parallelepiped with $n$ edges defined by the columns of the Jacobian (see a calculus reference such as \\[7\\] for more details).\n",
"\n",
- "Similar to the univariate case, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have $L$ transforms $g_{(0)}, g_{(1)},\\ldots,g_{(L-1)}$, then the log-density of the transformed variable $Y=(g_{(0)}\\circ g_{(1)}\\circ\\cdots\\circ g_{(L-1)})(X)$ is\n",
- "\n",
+ "Similar to the univariate case, we can compose such bijective transformations to produce even more complex distributions. By an inductive argument, if we have a sequence of $L$ transforms $(g_1, g_2, \\ldots, g_L)$ such that $Y = g(X) = g_L \\circ \\cdots \\circ g_2 \\circ g_1(X)$, then the log-density of $Y$ is\n",
"\n",
"\\begin{align}\n",
- "\\log(p_Y(y)) &= \\log\\left(p_X\\left(\\left(g_{(L-1)}^{-1}\\circ\\cdots\\circ g_{(0)}^{-1}\\right)\\left(y\\right)\\right)\\right)+\\sum^{L-1}_{l=0}\\log\\left(\\left|\\frac{dg^{-1}_{(l)}(y_{(l)})}{dy'}\\right|\\right),\n",
- "%\\left( g^{(l)}(y^{(l)})\n",
- "%\\right).\n",
+ "\\log p_Y(y) = \\log p_X(y_0) + \\sum^{L}_{l=1} \\log \\left| \\det \\frac{dg^{-1}_{l}(y_l)}{dy_{l}} \\right|\n",
"\\end{align}\n",
"\n",
- "where we've defined $y_{(0)}=x$, $y_{(L-1)}=y$ for convenience of notation.\n",
+ "where $y_L = y$ and $y_{l-1} = g^{-1}_l(y_{l})$.\n",
"\n",
- "The main challenge is in designing parametrizable multivariate bijections that have closed form expressions for both $g$ and $g^{-1}$, a tractable Jacobian whose calculation scales with $O(D)$ rather than $O(D^3)$, and can express a flexible class of functions."
+ "The main challenge is in designing parametrizable multivariate bijections that have closed form expressions for both $g$ and $g^{-1}$, a tractable Jacobian determinant whose calculation scales with $O(D)$ rather than $O(D^3)$, and can express a flexible class of functions."
]
},
{
@@ -571,7 +553,7 @@
" loss.backward()\n",
" optimizer.step()\n",
" flow_dist.clear_cache()\n",
- " \n",
+ "\n",
" if step % 500 == 0:\n",
" print('step: {}, loss: {}'.format(step, loss.item()))"
]
@@ -613,24 +595,24 @@
"plt.show()\n",
"\n",
"plt.subplot(1, 2, 1)\n",
- "sns.distplot(X[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X[:,0], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,0], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='flow')\n",
"plt.title(r'$p(x_1)$')\n",
"plt.subplot(1, 2, 2)\n",
- "sns.distplot(X[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X[:,1], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,1], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
@@ -799,7 +781,7 @@
" optimizer.step()\n",
" dist_x1.clear_cache()\n",
" dist_x2_given_x1.clear_cache()\n",
- " \n",
+ "\n",
" if step % 500 == 0:\n",
" print('step: {}, loss: {}'.format(step, loss.item()))"
]
@@ -845,24 +827,24 @@
"plt.show()\n",
"\n",
"plt.subplot(1, 2, 1)\n",
- "sns.distplot(X[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X[:,0], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,0], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,0], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='flow')\n",
"plt.title(r'$p(x_1)$')\n",
"plt.subplot(1, 2, 2)\n",
- "sns.distplot(X[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X[:,1], hist=False, kde=True,\n",
" bins=None,\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
" label='data')\n",
- "sns.distplot(X_flow[:,1], hist=False, kde=True, \n",
+ "sns.distplot(X_flow[:,1], hist=False, kde=True,\n",
" bins=None, color='firebrick',\n",
" hist_kws={'edgecolor':'black'},\n",
" kde_kws={'linewidth': 2},\n",
@@ -897,13 +879,6 @@
"9. Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. [*Density estimation using Real-NVP*](https://arxiv.org/abs/1605.08803). Conference paper at ICLR 2017.\n",
"10. David Ha, Andrew Dai, Quoc V. Le. [*HyperNetworks*](https://arxiv.org/abs/1609.09106). Workshop contribution at ICLR 2017."
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/tutorial/source/svi_flow_guide.ipynb b/tutorial/source/svi_flow_guide.ipynb
new file mode 100644
index 0000000000..4b0fe1c89d
--- /dev/null
+++ b/tutorial/source/svi_flow_guide.ipynb
@@ -0,0 +1,238 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# SVI with a Normalizing Flow guide\n",
+ "\n",
+ "Thanks to their expressiveness, normalizing flows (see [normalizing flow introduction](normalizing_flows_intro.ipynb)) are great guide candidates for stochastic variational inference (SVI). This notebook demonstrates how to perform amortized SVI with a normalizing flow as guide.\n",
+ "\n",
+ "> In this notebook we use [Zuko](https://zuko.readthedocs.io/) to implement normalizing flows, but similar results can be obtained with other PyTorch-based flow libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pyro\n",
+ "import torch\n",
+ "import zuko # pip install zuko\n",
+ "\n",
+ "from corner import corner, overplot_points # pip install corner\n",
+ "from pyro.contrib.zuko import ZukoToPyro\n",
+ "from pyro.optim import ClippedAdam\n",
+ "from pyro.infer import SVI, Trace_ELBO\n",
+ "from torch import Tensor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model\n",
+ "\n",
+ "We define a simple non-linear model $p(x | z)$ with a standard Gaussian prior $p(z)$ over the latent variables $z$."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prior = pyro.distributions.Normal(torch.zeros(3), torch.ones(3)).to_event(1)\n",
+ "\n",
+ "def likelihood(z: Tensor):\n",
+ " mu = z[..., :2]\n",
+ " rho = z[..., 2].tanh() * 0.99\n",
+ "\n",
+ " cov = 1e-2 * torch.stack([\n",
+ " torch.ones_like(rho), rho,\n",
+ " rho, torch.ones_like(rho),\n",
+ " ], dim=-1).unflatten(-1, (2, 2))\n",
+ "\n",
+ " return pyro.distributions.MultivariateNormal(mu, cov)\n",
+ "\n",
+ "def model(x: Tensor):\n",
+ " with pyro.plate(\"data\", x.shape[1]):\n",
+ " z = pyro.sample(\"z\", prior)\n",
+ "\n",
+ " with pyro.plate(\"obs\", 5):\n",
+ " pyro.sample(\"x\", likelihood(z), obs=x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We sample 64 reference latent variables and observations $(z^*, x^*)$. In practice, $z^*$ is unknown, and $x^*$ is your data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "z_star = prior.sample((64,))\n",
+ "x_star = likelihood(z_star).sample((5,))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Guide\n",
+ "\n",
+ "We define the guide $q_\\phi(z | x)$ with a normalizing flow. We choose a conditional [neural spline flow](https://arxiv.org/abs/1906.04032) borrowed from the [Zuko](https://zuko.readthedocs.io/) library. Because Zuko distributions are very similar to Pyro distributions, a thin wrapper (`ZukoToPyro`) is sufficient to make Zuko and Pyro 100% compatible."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flow = zuko.flows.NSF(features=3, context=10, transforms=1, hidden_features=(256, 256))\n",
+ "flow.transform = flow.transform.inv # inverse autoregressive flows (IAF) are fast to sample from\n",
+ "\n",
+ "def guide(x: Tensor):\n",
+ " pyro.module(\"flow\", flow)\n",
+ "\n",
+ " with pyro.plate(\"data\", x.shape[1]): # amortized\n",
+ " pyro.sample(\"z\", ZukoToPyro(flow(x.transpose(0, 1).flatten(-2))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## SVI\n",
+ "\n",
+ "We train our guide with a standard stochastic variational inference (SVI) pipeline. We use 16 particles to reduce the variance of the ELBO and clip the norm of the gradients to make training more stable."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(0) 209195.08367919922\n",
+ "(256) -25.225540161132812\n",
+ "(512) -99.09033203125\n",
+ "(768) -102.66302490234375\n",
+ "(1024) -138.8058319091797\n",
+ "(1280) -92.15625\n",
+ "(1536) -136.78167724609375\n",
+ "(1792) -87.76119995117188\n",
+ "(2048) -116.21714782714844\n",
+ "(2304) -162.0266571044922\n",
+ "(2560) -91.13175964355469\n",
+ "(2816) -164.86270141601562\n",
+ "(3072) -98.17607116699219\n",
+ "(3328) -102.58432006835938\n",
+ "(3584) -151.61912536621094\n",
+ "(3840) -77.94436645507812\n",
+ "(4096) -121.82719421386719\n"
+ ]
+ }
+ ],
+ "source": [
+ "pyro.clear_param_store()\n",
+ "\n",
+ "svi = SVI(model, guide, optim=ClippedAdam({\"lr\": 1e-3, \"clip_norm\": 10.0}), loss=Trace_ELBO(num_particles=16, vectorize_particles=True))\n",
+ "\n",
+ "for step in range(4096 + 1):\n",
+ " elbo = svi.step(x_star)\n",
+ "\n",
+ " if step % 256 == 0:\n",
+ " print(f'({step})', elbo)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Posterior predictive"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "