Commit 023752e: Update example

1 parent 97d02f8 commit 023752e

1 file changed (+53, -64 lines)


examples/gpax_simpleGP.ipynb (+53, -64)
@@ -70,30 +70,13 @@
 {
 "cell_type": "code",
 "metadata": {
-"id": "VQ1rLUzqha2i",
-"outputId": "462cff6e-ea94-48f5-a6b0-a8bd1f255bb4",
-"colab": {
-"base_uri": "https://localhost:8080/"
-}
+"id": "VQ1rLUzqha2i"
 },
 "source": [
-"!pip install -q git+https://github.com/ziatdinovmax/gpax.git"
+"!pip install -q gpax"
 ],
-"execution_count": 1,
-"outputs": [
-{
-"output_type": "stream",
-"name": "stdout",
-"text": [
-" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
-" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
-" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
-"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m312.7/312.7 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m371.0/371.0 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-"\u001b[?25h Building wheel for gpax (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
-]
-}
-]
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
@@ -154,7 +137,7 @@
 "height": 449
 },
 "id": "-I4RQ2xCi0VV",
-"outputId": "5fb08e30-3dab-4147-a1a1-93e23817b330"
+"outputId": "e80e0974-0e92-47fd-c1e0-27885f78b6c4"
 },
 "source": [
 "np.random.seed(0)\n",
@@ -214,24 +197,24 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "c7kXm_lui6Dy",
-"outputId": "0b260517-4273-4ce7-b04e-1911da22a3a5"
+"outputId": "c8daf638-fea1-420b-9c3b-099999352d6a"
 },
 "source": [
 "# Get random number generator keys for training and prediction\n",
-"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
+"key1, key2 = gpax.utils.get_keys()\n",
 "\n",
 "# Initialize model\n",
 "gp_model = gpax.ExactGP(1, kernel='RBF')\n",
 "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
-"gp_model.fit(rng_key, X, y, num_chains=1)"
+"gp_model.fit(key1, X, y, num_chains=1)"
 ],
 "execution_count": 4,
 "outputs": [
 {
 "output_type": "stream",
 "name": "stderr",
 "text": [
-"sample: 100%|██████████| 4000/4000 [00:14<00:00, 277.25it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n"
+"sample: 100%|██████████| 4000/4000 [00:10<00:00, 375.60it/s, 7 steps of size 5.52e-01. acc. prob=0.90] \n"
 ]
 },
 {
@@ -263,12 +246,12 @@
 "\n",
 "$$𝜇^{post}_*= \\frac{1}{L} ∑_{i=1}^L 𝜇_*^i,$$\n",
 "\n",
-"which corresponds to the ```y_pred``` in the code cell below, and\n",
+"which corresponds to the ```posterior_mean``` in the code cell below, and\n",
 "samples\n",
 "\n",
 "$$f_*^i∼MVNormal(𝜇^i_*, 𝛴^i_*)$$\n",
 "\n",
-"from multivariate normal distributions for all the pairs of predictive means and covariances (```y_sampled``` in the code cell below). Note that model noise is absorbed into the kernel computation function."
+"from multivariate normal distributions for all the pairs of predictive means and covariances (```f_samples``` in the code cell below). Note that model noise is absorbed into the kernel computation function."
 ]
 },
 {
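For orientation, the renamed pieces of the API fit together as sketched below. This is a consolidated sketch of the updated cells that appear in the following hunks; the training data X, y and the numpy import come from earlier notebook cells that are not part of this diff.

import gpax
import numpy as np

# Random number generator keys for training and prediction
# (renamed in this commit from rng_key / rng_key_predict to key1 / key2)
key1, key2 = gpax.utils.get_keys()

# Initialize a 1D exact GP with an RBF kernel
gp_model = gpax.ExactGP(1, kernel='RBF')

# Run Hamiltonian Monte Carlo to obtain posterior samples for the kernel
# parameters and model noise (X, y are defined in earlier cells of the notebook)
gp_model.fit(key1, X, y, num_chains=1)

# posterior_mean is the HMC-averaged predictive mean; f_samples holds n draws
# from the predictive MVNormal of every HMC sample (renamed from y_pred / y_sampled)
X_test = np.linspace(-1, 1, 100)
posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)
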
@@ -281,7 +264,7 @@
 "X_test = np.linspace(-1, 1, 100)\n",
 "# Get the GP prediction. Here n stands for the number of samples from each MVNormal distribution\n",
 "# (the total number of MVNormal distributions is equal to the number of HMC samples)\n",
-"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
+"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
 ],
 "execution_count": 5,
 "outputs": []
@@ -303,17 +286,17 @@
 "height": 449
 },
 "id": "lJIdx7fUnX-W",
-"outputId": "29465923-a54e-44c3-ffe6-6d38d06dcbce"
+"outputId": "f8076d8c-cbc1-49de-9b71-2c6d6665bbb5"
 },
 "source": [
 "_, ax = plt.subplots(dpi=100)\n",
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
-"for y1 in y_sampled:\n",
+"for y1 in f_samples:\n",
 " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
-"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
+"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
 "ax.legend(loc='upper left')\n",
 "l.set_alpha(0)\n",
 "ax.set_ylim(-1.8, 2.2);"
@@ -345,7 +328,7 @@
 "cell_type": "code",
 "metadata": {
 "id": "7R0jWHFLtQ5b",
-"outputId": "a2123685-171f-4d42-be56-88d63acecb70",
+"outputId": "d4878597-20e5-41cf-bd66-c23aea38556e",
 "colab": {
 "base_uri": "https://localhost:8080/",
 "height": 449
@@ -356,8 +339,10 @@
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
-"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
+"ax.fill_between(X_test,\n",
+" posterior_mean - f_samples.std(axis=(0,1)),\n",
+" posterior_mean + f_samples.std(axis=(0,1)),\n",
 " color='r', alpha=0.3, label=\"Model uncertainty\")\n",
 "ax.legend(loc='upper left')\n",
 "ax.set_ylim(-1.8, 2.2);"
@@ -426,7 +411,7 @@
 ],
 "metadata": {
 "id": "znLfcvK0HSqY",
-"outputId": "becc4a51-9c6d-40ac-93c6-2972a66881a0",
+"outputId": "e4197eaf-64cd-4ea2-93a4-887ed12e9521",
 "colab": {
 "base_uri": "https://localhost:8080/",
 "height": 449
@@ -470,26 +455,26 @@
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
-"outputId": "4506ef74-5e1e-4bb5-b239-8fcfbc045c3d",
+"outputId": "5ccd70c9-a530-40dc-b403-a6cbba39fcd1",
 "id": "Qwx5D237IVdC"
 },
 "source": [
 "# Get random number generator keys for training and prediction\n",
-"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
+"key1, key2 = gpax.utils.get_keys()\n",
 "\n",
 "# Initialize model\n",
 "gp_model = gpax.ExactGP(1, kernel='RBF')\n",
 "\n",
 "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
-"gp_model.fit(rng_key, X, y, num_chains=1)"
+"gp_model.fit(key1, X, y, num_chains=1)"
 ],
 "execution_count": 9,
 "outputs": [
 {
 "output_type": "stream",
 "name": "stderr",
 "text": [
-"sample: 100%|██████████| 4000/4000 [00:07<00:00, 547.81it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n"
+"sample: 100%|██████████| 4000/4000 [00:05<00:00, 706.59it/s, 15 steps of size 2.69e-01. acc. prob=0.78]\n"
 ]
 },
 {
@@ -523,7 +508,7 @@
 "source": [
 "X_test = np.linspace(-1, 1, 100)\n",
 "\n",
-"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
+"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
 ],
 "execution_count": 10,
 "outputs": []
@@ -544,18 +529,18 @@
 "base_uri": "https://localhost:8080/",
 "height": 449
 },
-"outputId": "d77186a5-f0f2-465b-a653-cebe8503fda6",
+"outputId": "8ee0ee9f-16a8-4628-abe2-235e4ef75081",
 "id": "rv-9uBCHIhIo"
 },
 "source": [
 "_, ax = plt.subplots(dpi=100)\n",
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
-"for y1 in y_sampled:\n",
+"for y1 in f_samples:\n",
 " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
-"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
+"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
 "ax.legend(loc='upper left')\n",
 "l.set_alpha(0)\n",
 "ax.set_ylim(-1.2, 1.2);"
@@ -595,7 +580,7 @@
 {
 "cell_type": "code",
 "metadata": {
-"outputId": "2790be52-fe38-4c51-c09b-5501738cf87f",
+"outputId": "103d087a-9a93-4be5-c1bb-c31a5f8b5438",
 "colab": {
 "base_uri": "https://localhost:8080/",
 "height": 449
@@ -607,8 +592,10 @@
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
-"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
+"ax.fill_between(X_test,\n",
+" posterior_mean - f_samples.std(axis=(0,1)),\n",
+" posterior_mean + f_samples.std(axis=(0,1)),\n",
 " color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
 "ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
 "ax.legend(loc='upper left')\n",
@@ -684,7 +671,7 @@
 ],
 "metadata": {
 "id": "UnBPC1RoXAZ8",
-"outputId": "3e72357a-d26c-41ec-c2f6-e7fcc6b8aa10",
+"outputId": "98e3fa51-fbfd-4fa3-9607-fbdbdb70f6b9",
 "colab": {
 "base_uri": "https://localhost:8080/",
 "height": 430
@@ -717,17 +704,17 @@
 "cell_type": "code",
 "source": [
 "# Get random number generator keys for training and prediction\n",
-"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
+"key1, key2 = gpax.utils.get_keys()\n",
 "\n",
 "# Initialize model\n",
 "gp_model = gpax.ExactGP(1, kernel='RBF', lengthscale_prior_dist=lengthscale_prior_dist)\n",
 "\n",
 "# Run Hamiltonian Monte Carlo to obtain posterior samples for kernel parameters and model noise\n",
-"gp_model.fit(rng_key, X, y, num_chains=1)"
+"gp_model.fit(key1, X, y, num_chains=1)"
 ],
 "metadata": {
 "id": "nV9SLaAEnv6v",
-"outputId": "f02ef5d1-4bac-4fdd-bf07-9704293e442f",
+"outputId": "ec8dae3c-7ae9-458c-898d-265ab62190c1",
 "colab": {
 "base_uri": "https://localhost:8080/"
 }
@@ -738,7 +725,7 @@
 "output_type": "stream",
 "name": "stderr",
 "text": [
-"sample: 100%|██████████| 4000/4000 [00:07<00:00, 550.79it/s, 7 steps of size 4.26e-01. acc. prob=0.94] \n"
+"sample: 100%|██████████| 4000/4000 [00:06<00:00, 647.22it/s, 7 steps of size 4.26e-01. acc. prob=0.94]\n"
 ]
 },
 {
@@ -767,7 +754,7 @@
 {
 "cell_type": "code",
 "source": [
-"y_pred, y_sampled = gp_model.predict(rng_key_predict, X_test, n=200)"
+"posterior_mean, f_samples = gp_model.predict(key2, X_test, n=200)"
 ],
 "metadata": {
 "id": "W9woiXs5oFrR"
@@ -791,10 +778,10 @@
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=1, label=\"Noisy observations\", alpha=0.7)\n",
-"for y1 in y_sampled:\n",
+"for y1 in f_samples:\n",
 " ax.plot(X_test, y1.mean(0), lw=.1, zorder=0, c='r', alpha=.1)\n",
-"l, = ax.plot(X_test, y_sampled[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
+"l, = ax.plot(X_test, f_samples[0].mean(0), lw=1, c='r', alpha=1, label=\"Sampled predictions\")\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=1, c='b', label='Posterior mean')\n",
 "ax.legend(loc='upper left')\n",
 "l.set_alpha(0)\n",
 "ax.set_ylim(-1.2, 1.2);"
@@ -805,9 +792,9 @@
 "height": 449
 },
 "id": "y9gHGmwD3lv-",
-"outputId": "b02ecbbf-4091-42f3-f398-037f915cff6c"
+"outputId": "24b2604a-a4b6-4715-d1b4-0d3476b8ef62"
 },
-"execution_count": 17,
+"execution_count": 18,
 "outputs": [
 {
 "output_type": "display_data",
@@ -846,22 +833,24 @@
 "ax.set_xlabel(\"$x$\")\n",
 "ax.set_ylabel(\"$y$\")\n",
 "ax.scatter(X, y, marker='x', c='k', zorder=2, label=\"Noisy observations\", alpha=0.7)\n",
-"ax.plot(X_test, y_pred, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
-"ax.fill_between(X_test, y_pred - y_sampled.std(axis=(0,1)), y_pred + y_sampled.std(axis=(0,1)),\n",
+"ax.plot(X_test, posterior_mean, lw=1.5, zorder=2, c='b', label='Posterior mean')\n",
+"ax.fill_between(X_test,\n",
+" posterior_mean - f_samples.std(axis=(0,1)),\n",
+" posterior_mean + f_samples.std(axis=(0,1)),\n",
 " color='r', alpha=0.3, label=\"Model uncertainty (2$\\sigma$)\")\n",
 "ax.plot(X_test, f(X_test), color='k', alpha=0.7, zorder=1, label='Ground truth')\n",
 "ax.legend(loc='upper left')\n",
 "ax.set_ylim(-1.2, 1.2);"
 ],
 "metadata": {
 "id": "fm0e70PIoJvE",
-"outputId": "85b7ca44-26f0-4e44-d4b9-47f82ca8b5d1",
+"outputId": "8584ac26-b827-454d-aa78-75800863bd69",
 "colab": {
 "base_uri": "https://localhost:8080/",
 "height": 449
 }
 },
-"execution_count": 18,
+"execution_count": 19,
 "outputs": [
 {
 "output_type": "display_data",
