Skip to content

Commit

Permalink
Minor fixes in tutorials (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmheiko authored Apr 16, 2024
1 parent b54af4a commit ab04d4f
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions notebooks/tutorial_part2_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
"As mentioned above, we call a promitive program without any observe statements a *kernel program*.\n",
"In combinators each primitive program denotes two densities:\n",
"1. a **prior density**, which is defined as the joint density over all unobserverd variables in the program\n",
"2. an **unnormalized target density**, which is defined as the joint density over all variables in the program\n",
"2. an **unnormalized target density**, which is defined as the prior density multiplied by the product over the densities of the observed variables in the program\n",
"\n",
"To get a better understanding of these densities and why their distinction is important, let's visualize these densities for the primitive program `f` that we defined above:\n",
"1. The prior density is given by the denstity of the normal distribution\n",
Expand Down Expand Up @@ -326,7 +326,7 @@
"_, f_batch_trace, f_batch_metrics = traced_evaluate(\n",
" numpyro.plate(\"particle_plate\", 10000)(f), seed=0\n",
")()\n",
"approx_target_sampels = f_batch_trace[\"x\"][\"value\"]\n",
"approx_target_samples = f_batch_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(f_batch_metrics[\"log_weight\"])\n",
"\n",
"print(\"Normalizing constant:\", Z_target)\n",
Expand Down Expand Up @@ -369,7 +369,7 @@
" color=\"C1\",\n",
")\n",
"_ = plt.hist(\n",
" approx_target_sampels,\n",
" approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
Expand Down Expand Up @@ -564,16 +564,14 @@
" m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
")\n",
"handles, labels = ax_xy.get_legend_handles_labels()\n",
"handles.extend(\n",
" [\n",
" lines.Line2D(\n",
" [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n",
" ),\n",
" lines.Line2D(\n",
" [0], [0], label=\"target denstity of $extend(f,\\ k)$\", color=\"C1\"\n",
" ),\n",
" ]\n",
")\n",
"handles.extend([\n",
" lines.Line2D(\n",
" [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n",
" ),\n",
" lines.Line2D(\n",
" [0], [0], label=\"target denstity of $extend(f,\\ k)$\", color=\"C1\"\n",
" ),\n",
"])\n",
"ax_xy.legend(handles=handles, loc=\"lower left\");"
]
},
Expand Down Expand Up @@ -772,7 +770,7 @@
"_, q2_trace, q2_metrics = traced_evaluate(q2, seed=0)()\n",
"_, _, f_batch_metrics = traced_evaluate(f_batch, seed=0)()\n",
"\n",
"approx_target_sampels = q2_trace[\"x\"][\"value\"]\n",
"approx_target_samples = q2_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(q2_metrics[\"log_weight\"])\n",
"weights_prior = np.exp(f_batch_metrics[\"log_weight\"])\n",
"ess = q2_metrics[\"ess\"]\n",
Expand Down Expand Up @@ -806,7 +804,7 @@
" color=\"C2\",\n",
")\n",
"_ = plt.hist(\n",
" approx_target_sampels,\n",
" approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
Expand Down Expand Up @@ -874,7 +872,7 @@
"source": [
"q3 = coix.resample(q2)\n",
"_, q3_trace, q3_metrics = traced_evaluate(q3, seed=0)()\n",
"approx_target_sampels = q3_trace[\"x\"][\"value\"]\n",
"approx_target_samples = q3_trace[\"x\"][\"value\"]\n",
"weights = jnp.exp(q3_metrics[\"log_weight\"])\n",
"print(\"The log weights after resampling are all equal:\", weights)\n",
"\n",
Expand All @@ -885,7 +883,7 @@
" color=\"C1\",\n",
")\n",
"_ = plt.hist(\n",
" approx_target_sampels,\n",
" approx_target_samples,\n",
" weights=weights,\n",
" density=True,\n",
" bins=100,\n",
Expand Down Expand Up @@ -968,22 +966,20 @@
" m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n",
")\n",
"handles, labels = ax_xy.get_legend_handles_labels()\n",
"handles.extend(\n",
" [\n",
" lines.Line2D(\n",
" [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n",
" ),\n",
" lines.Line2D(\n",
" [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n",
" ),\n",
" lines.Line2D(\n",
" [0],\n",
" [0],\n",
" label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n",
" color=\"C1\",\n",
" ),\n",
" ]\n",
")\n",
"handles.extend([\n",
" lines.Line2D(\n",
" [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n",
" ),\n",
" lines.Line2D(\n",
" [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n",
" ),\n",
" lines.Line2D(\n",
" [0],\n",
" [0],\n",
" label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n",
" color=\"C1\",\n",
" ),\n",
"])\n",
"ax_xy.legend(handles=handles, loc=\"lower left\")\n",
"\n",
"_, f_ext_trace, f_ext_metrics = traced_evaluate(\n",
Expand Down Expand Up @@ -1014,7 +1010,7 @@
"source": [
"### Takeaway\n",
"\n",
"We are now ready to start combining programs using inference combinators and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce propoerly weighted sampels for the target densities they define.\n",
"We are now ready to start combining programs using inference combinators and as long as we follow the rules of the grammar the resulting programs are valid, in the sense that they produce propoerly weighted samples for the target densities they define.\n",
"\n",
"To ensure that all evaluations are properly weighted, more general programs are more restricted in the ways they can be combined with other programs. If in doubt, check the grammar!\n",
"\n",
Expand Down

0 comments on commit ab04d4f

Please sign in to comment.