|
70 | 70 | {
|
71 | 71 | "cell_type": "code",
|
72 | 72 | "metadata": {
|
73 |
| - "id": "VQ1rLUzqha2i", |
74 |
| - "outputId": "462cff6e-ea94-48f5-a6b0-a8bd1f255bb4", |
75 |
| - "colab": { |
76 |
| - "base_uri": "https://localhost:8080/" |
77 |
| - } |
| 73 | + "id": "VQ1rLUzqha2i" |
78 | 74 | },
|
79 | 75 | "source": [
|
80 |
| - "!pip install -q git+https://github.com/ziatdinovmax/gpax.git" |
| 76 | + "!pip install -q gpax" |
81 | 77 | ],
|
82 |
| - "execution_count": 1, |
83 |
| - "outputs": [ |
84 |
| - { |
85 |
| - "output_type": "stream", |
86 |
| - "name": "stdout", |
87 |
| - "text": [ |
88 |
| - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", |
89 |
| - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", |
90 |
| - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", |
91 |
| - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m312.7/312.7 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
92 |
| - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m371.0/371.0 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", |
93 |
| - "\u001b[?25h Building wheel for gpax (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" |
94 |
| - ] |
95 |
| - } |
96 |
| - ] |
| 78 | + "execution_count": null, |
| 79 | + "outputs": [] |
97 | 80 | },
|
98 | 81 | {
|
99 | 82 | "cell_type": "markdown",
|
|
154 | 137 | "height": 449
|
155 | 138 | },
|
156 | 139 | "id": "-I4RQ2xCi0VV",
|
157 |
| - "outputId": "5fb08e30-3dab-4147-a1a1-93e23817b330" |
| 140 | + "outputId": "e80e0974-0e92-47fd-c1e0-27885f78b6c4" |
158 | 141 | },
|
159 | 142 | "source": [
|
160 | 143 | "np.random.seed(0)\n",
|
|
214 | 197 | "base_uri": "https://localhost:8080/"
|
215 | 198 | },
|
216 | 199 | "id": "c7kXm_lui6Dy",
|
217 |
| - "outputId": "0b260517-4273-4ce7-b04e-1911da22a3a5" |
| 200 | + "outputId": "c8daf638-fea1-420b-9c3b-099999352d6a" |
218 | 201 | },
|
219 | 202 | "source": [
|
220 | 203 | "# Get random number generator keys for training and prediction\n",
|
221 |
| - "rng_key, rng_key_predict = gpax.utils.get_keys()\n", |
| 204 | + "key1, key2 = gpax.utils.get_keys()\n", |
222 | 205 | "\n",
|
223 | 206 | "# Initialize model\n",
|
224 | 207 | "gp_model = gpax.ExactGP(1, kernel='RBF')\n",
|
225 | 208 | "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
|
226 |
| - "gp_model.fit(rng_key, X, y, num_chains=1)" |
| 209 | + "gp_model.fit(key1, X, y, num_chains=1)" |
227 | 210 | ],
|
228 | 211 | "execution_count": 4,
|
229 | 212 | "outputs": [
|
230 | 213 | {
|
231 | 214 | "output_type": "stream",
|
232 | 215 | "name": "stderr",
|
233 | 216 | "text": [
|
234 |
| - "sample: 100%|██████████| 4000/4000 [00:14<00:00, 277.25it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n" |
| 217 | + "sample: 100%|██████████| 4000/4000 [00:10<00:00, 375.60it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n" |
235 | 218 | ]
|
236 | 219 | },
|
237 | 220 | {
|
|
263 | 246 | "\n",
|
264 | 247 | "$$𝜇^{post}_*= \\frac{1}{L} ∑_{i=1}^L 𝜇_*^i,$$\n",
|
265 | 248 | "\n",
|
266 |
| - "which corresponds to the ```y_pred``` in the code cell below, and\n", |
| 249 | + "which corresponds to the ```posterior_mean``` in the code cell below, and\n", |
267 | 250 | "samples\n",
|
268 | 251 | "\n",
|
269 | 252 | "$$f_*^i∼MVNormal(𝜇^i_*, 𝛴^i_*)$$\n",
|
270 | 253 | "\n",
|
271 |
| - "from multivariate normal distributions for all the pairs of predictive means and covariances (```y_sampled``` in the code cell below). Note that model noise is absorbed into the kernel computation function." |
| 254 | + "from multivariate normal distributions for all the pairs of predictive means and covariances (```f_samples``` in the code cell below). Note that model noise is absorbed into the kernel computation function." |
272 | 255 | ]
|
273 | 256 | },
|
274 | 257 | {
|
|
281 | 264 | "X_test = np.linspace(-1, 1, 100)\n",
|
282 | 265 | "# Get the GP prediction. Here n stands for the number of samples from each MVNormal distribution\n",
|
283 | 266 | "# (the total number of MVNormal distributions is equal to the number of HMC samples)\n",
|
284 |
| - "y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)" |
| 267 | + "posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)" |
285 | 268 | ],
|
286 | 269 | "execution_count": 5,
|
287 | 270 | "outputs": []
|
|
303 | 286 | "height": 449
|
304 | 287 | },
|
305 | 288 | "id": "lJIdx7fUnX-W",
|
306 |
| - "outputId": "29465923-a54e-44c3-ffe6-6d38d06dcbce" |
| 289 | + "outputId": "f8076d8c-cbc1-49de-9b71-2c6d6665bbb5" |
307 | 290 | },
|
308 | 291 | "source": [
|
309 | 292 | "_, ax = plt.subplots(dpi=100)\n",
|
310 | 293 | "ax.set_xlabel(\"$x$\")\n",
|
311 | 294 | "ax.set_ylabel(\"$y$\")\n",
|
312 | 295 | "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
|
313 |
| - "for y1 in y_sampled:\n", |
| 296 | + "for y1 in f_samples:\n", |
314 | 297 | " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
|
315 |
| - "l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
316 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
| 298 | + "l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
| 299 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
317 | 300 | "ax.legend(loc='upper left')\n",
|
318 | 301 | "l.set_alpha(0)\n",
|
319 | 302 | "ax.set_ylim(-1.8, 2.2);"
|
|
345 | 328 | "cell_type": "code",
|
346 | 329 | "metadata": {
|
347 | 330 | "id": "7R0jWHFLtQ5b",
|
348 |
| - "outputId": "a2123685-171f-4d42-be56-88d63acecb70", |
| 331 | + "outputId": "d4878597-20e5-41cf-bd66-c23aea38556e", |
349 | 332 | "colab": {
|
350 | 333 | "base_uri": "https://localhost:8080/",
|
351 | 334 | "height": 449
|
|
356 | 339 | "ax.set_xlabel(\"$x$\")\n",
|
357 | 340 | "ax.set_ylabel(\"$y$\")\n",
|
358 | 341 | "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
|
359 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
360 |
| - "ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n", |
| 342 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
| 343 | + "ax.fill_between(X_test,\n", |
| 344 | + " posterior_mean - f_samples.std(axis=(0,1)),\n", |
| 345 | + " posterior_mean + f_samples.std(axis=(0,1)),\n", |
361 | 346 | " color='r', alpha=0.3, label=\"Model uncertainty\")\n",
|
362 | 347 | "ax.legend(loc='upper left')\n",
|
363 | 348 | "ax.set_ylim(-1.8, 2.2);"
|
|
426 | 411 | ],
|
427 | 412 | "metadata": {
|
428 | 413 | "id": "znLfcvK0HSqY",
|
429 |
| - "outputId": "becc4a51-9c6d-40ac-93c6-2972a66881a0", |
| 414 | + "outputId": "e4197eaf-64cd-4ea2-93a4-887ed12e9521", |
430 | 415 | "colab": {
|
431 | 416 | "base_uri": "https://localhost:8080/",
|
432 | 417 | "height": 449
|
|
470 | 455 | "colab": {
|
471 | 456 | "base_uri": "https://localhost:8080/"
|
472 | 457 | },
|
473 |
| - "outputId": "4506ef74-5e1e-4bb5-b239-8fcfbc045c3d", |
| 458 | + "outputId": "5ccd70c9-a530-40dc-b403-a6cbba39fcd1", |
474 | 459 | "id": "Qwx5D237IVdC"
|
475 | 460 | },
|
476 | 461 | "source": [
|
477 | 462 | "# Get random number generator keys for training and prediction\n",
|
478 |
| - "rng_key, rng_key_predict = gpax.utils.get_keys()\n", |
| 463 | + "key1, key2 = gpax.utils.get_keys()\n", |
479 | 464 | "\n",
|
480 | 465 | "# Initialize model\n",
|
481 | 466 | "gp_model = gpax.ExactGP(1, kernel='RBF')\n",
|
482 | 467 | "\n",
|
483 | 468 | "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
|
484 |
| - "gp_model.fit(rng_key, X, y, num_chains=1)" |
| 469 | + "gp_model.fit(key1, X, y, num_chains=1)" |
485 | 470 | ],
|
486 | 471 | "execution_count": 9,
|
487 | 472 | "outputs": [
|
488 | 473 | {
|
489 | 474 | "output_type": "stream",
|
490 | 475 | "name": "stderr",
|
491 | 476 | "text": [
|
492 |
| - "sample: 100%|██████████| 4000/4000 [00:07<00:00, 547.81it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n" |
| 477 | + "sample: 100%|██████████| 4000/4000 [00:05<00:00, 706.59it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n" |
493 | 478 | ]
|
494 | 479 | },
|
495 | 480 | {
|
|
523 | 508 | "source": [
|
524 | 509 | "X_test = np.linspace(-1, 1, 100)\n",
|
525 | 510 | "\n",
|
526 |
| - "y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)" |
| 511 | + "posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)" |
527 | 512 | ],
|
528 | 513 | "execution_count": 10,
|
529 | 514 | "outputs": []
|
|
544 | 529 | "base_uri": "https://localhost:8080/",
|
545 | 530 | "height": 449
|
546 | 531 | },
|
547 |
| - "outputId": "d77186a5-f0f2-465b-a653-cebe8503fda6", |
| 532 | + "outputId": "8ee0ee9f-16a8-4628-abe2-235e4ef75081", |
548 | 533 | "id": "rv-9uBCHIhIo"
|
549 | 534 | },
|
550 | 535 | "source": [
|
551 | 536 | "_, ax = plt.subplots(dpi=100)\n",
|
552 | 537 | "ax.set_xlabel(\"$x$\")\n",
|
553 | 538 | "ax.set_ylabel(\"$y$\")\n",
|
554 | 539 | "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
|
555 |
| - "for y1 in y_sampled:\n", |
| 540 | + "for y1 in f_samples:\n", |
556 | 541 | " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
|
557 |
| - "l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
558 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
| 542 | + "l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
| 543 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
559 | 544 | "ax.legend(loc='upper left')\n",
|
560 | 545 | "l.set_alpha(0)\n",
|
561 | 546 | "ax.set_ylim(-1.2, 1.2);"
|
|
595 | 580 | {
|
596 | 581 | "cell_type": "code",
|
597 | 582 | "metadata": {
|
598 |
| - "outputId": "2790be52-fe38-4c51-c09b-5501738cf87f", |
| 583 | + "outputId": "103d087a-9a93-4be5-c1bb-c31a5f8b5438", |
599 | 584 | "colab": {
|
600 | 585 | "base_uri": "https://localhost:8080/",
|
601 | 586 | "height": 449
|
|
607 | 592 | "ax.set_xlabel(\"$x$\")\n",
|
608 | 593 | "ax.set_ylabel(\"$y$\")\n",
|
609 | 594 | "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
|
610 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
611 |
| - "ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n", |
| 595 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
| 596 | + "ax.fill_between(X_test,\n", |
| 597 | + " posterior_mean - f_samples.std(axis=(0,1)),\n", |
| 598 | + " posterior_mean + f_samples.std(axis=(0,1)),\n", |
612 | 599 | " color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
|
613 | 600 | "ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
|
614 | 601 | "ax.legend(loc='upper left')\n",
|
|
684 | 671 | ],
|
685 | 672 | "metadata": {
|
686 | 673 | "id": "UnBPC1RoXAZ8",
|
687 |
| - "outputId": "3e72357a-d26c-41ec-c2f6-e7fcc6b8aa10", |
| 674 | + "outputId": "98e3fa51-fbfd-4fa3-9607-fbdbdb70f6b9", |
688 | 675 | "colab": {
|
689 | 676 | "base_uri": "https://localhost:8080/",
|
690 | 677 | "height": 430
|
|
717 | 704 | "cell_type": "code",
|
718 | 705 | "source": [
|
719 | 706 | "# Get random number generator keys for training and prediction\n",
|
720 |
| - "rng_key, rng_key_predict = gpax.utils.get_keys()\n", |
| 707 | + "key1, key2 = gpax.utils.get_keys()\n", |
721 | 708 | "\n",
|
722 | 709 | "# Initialize model\n",
|
723 | 710 | "gp_model = gpax.ExactGP(1, kernel='RBF', lengthscale_prior_dist=lengthscale_prior_dist)\n",
|
724 | 711 | "\n",
|
725 | 712 | "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
|
726 |
| - "gp_model.fit(rng_key, X, y, num_chains=1)" |
| 713 | + "gp_model.fit(key1, X, y, num_chains=1)" |
727 | 714 | ],
|
728 | 715 | "metadata": {
|
729 | 716 | "id": "nV9SLaAEnv6v",
|
730 |
| - "outputId": "f02ef5d1-4bac-4fdd-bf07-9704293e442f", |
| 717 | + "outputId": "ec8dae3c-7ae9-458c-898d-265ab62190c1", |
731 | 718 | "colab": {
|
732 | 719 | "base_uri": "https://localhost:8080/"
|
733 | 720 | }
|
|
738 | 725 | "output_type": "stream",
|
739 | 726 | "name": "stderr",
|
740 | 727 | "text": [
|
741 |
| - "sample: 100%|██████████| 4000/4000 [00:07<00:00, 550.79it/s, 7 steps of size 4.26e-01. acc. prob=0.94] \n" |
| 728 | + "sample: 100%|██████████| 4000/4000 [00:06<00:00, 647.22it/s, 7 steps of size 4.26e-01. acc. prob=0.94]\n" |
742 | 729 | ]
|
743 | 730 | },
|
744 | 731 | {
|
|
767 | 754 | {
|
768 | 755 | "cell_type": "code",
|
769 | 756 | "source": [
|
770 |
| - "y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)" |
| 757 | + "posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)" |
771 | 758 | ],
|
772 | 759 | "metadata": {
|
773 | 760 | "id": "W9woiXs5oFrR"
|
|
791 | 778 | "ax.set_xlabel(\"$x$\")\n",
|
792 | 779 | "ax.set_ylabel(\"$y$\")\n",
|
793 | 780 | "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
|
794 |
| - "for y1 in y_sampled:\n", |
| 781 | + "for y1 in f_samples:\n", |
795 | 782 | " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
|
796 |
| - "l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
797 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
| 783 | + "l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n", |
| 784 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n", |
798 | 785 | "ax.legend(loc='upper left')\n",
|
799 | 786 | "l.set_alpha(0)\n",
|
800 | 787 | "ax.set_ylim(-1.2, 1.2);"
|
|
805 | 792 | "height": 449
|
806 | 793 | },
|
807 | 794 | "id": "y9gHGmwD3lv-",
|
808 |
| - "outputId": "b02ecbbf-4091-42f3-f398-037f915cff6c" |
| 795 | + "outputId": "24b2604a-a4b6-4715-d1b4-0d3476b8ef62" |
809 | 796 | },
|
810 |
| - "execution_count": 17, |
| 797 | + "execution_count": 18, |
811 | 798 | "outputs": [
|
812 | 799 | {
|
813 | 800 | "output_type": "display_data",
|
|
846 | 833 | "ax.set_xlabel(\"$x$\")\n",
|
847 | 834 | "ax.set_ylabel(\"$y$\")\n",
|
848 | 835 | "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
|
849 |
| - "ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
850 |
| - "ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n", |
| 836 | + "ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n", |
| 837 | + "ax.fill_between(X_test,\n", |
| 838 | + " posterior_mean - f_samples.std(axis=(0,1)),\n", |
| 839 | + " posterior_mean + f_samples.std(axis=(0,1)),\n", |
851 | 840 | " color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
|
852 | 841 | "ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
|
853 | 842 | "ax.legend(loc='upper left')\n",
|
854 | 843 | "ax.set_ylim(-1.2, 1.2);"
|
855 | 844 | ],
|
856 | 845 | "metadata": {
|
857 | 846 | "id": "fm0e70PIoJvE",
|
858 |
| - "outputId": "85b7ca44-26f0-4e44-d4b9-47f82ca8b5d1", |
| 847 | + "outputId": "8584ac26-b827-454d-aa78-75800863bd69", |
859 | 848 | "colab": {
|
860 | 849 | "base_uri": "https://localhost:8080/",
|
861 | 850 | "height": 449
|
862 | 851 | }
|
863 | 852 | },
|
864 |
| - "execution_count": 18, |
| 853 | + "execution_count": 19, |
865 | 854 | "outputs": [
|
866 | 855 | {
|
867 | 856 | "output_type": "display_data",
|
|
0 commit comments