Skip to content

Commit df2b2da

Browse files
committed
Update example
1 parent 05dda75 commit df2b2da

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

examples/GP_sGP.ipynb

+32-32
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"metadata": {
1717
"id": "136rQl-Z67Xf"
1818
},
@@ -70,7 +70,7 @@
7070
},
7171
{
7272
"cell_type": "code",
73-
"execution_count": 2,
73+
"execution_count": null,
7474
"metadata": {
7575
"id": "VQ1rLUzqha2i"
7676
},
@@ -90,7 +90,7 @@
9090
},
9191
{
9292
"cell_type": "code",
93-
"execution_count": 3,
93+
"execution_count": null,
9494
"metadata": {
9595
"id": "XCoyWlKt67Xk"
9696
},
@@ -110,7 +110,7 @@
110110
},
111111
{
112112
"cell_type": "code",
113-
"execution_count": 4,
113+
"execution_count": null,
114114
"metadata": {
115115
"id": "KtGDc11Ehh7r"
116116
},
@@ -133,7 +133,7 @@
133133
},
134134
{
135135
"cell_type": "code",
136-
"execution_count": 5,
136+
"execution_count": null,
137137
"metadata": {
138138
"id": "V5isV5Ho67Xl"
139139
},
@@ -144,7 +144,7 @@
144144
},
145145
{
146146
"cell_type": "code",
147-
"execution_count": 6,
147+
"execution_count": null,
148148
"metadata": {
149149
"id": "gUyKDZjM67Xl"
150150
},
@@ -179,7 +179,7 @@
179179
},
180180
{
181181
"cell_type": "code",
182-
"execution_count": 7,
182+
"execution_count": null,
183183
"metadata": {
184184
"id": "LAvbGDom67Xl"
185185
},
@@ -205,7 +205,7 @@
205205
},
206206
{
207207
"cell_type": "code",
208-
"execution_count": 8,
208+
"execution_count": null,
209209
"metadata": {
210210
"colab": {
211211
"base_uri": "https://localhost:8080/",
@@ -248,7 +248,7 @@
248248
},
249249
{
250250
"cell_type": "code",
251-
"execution_count": 9,
251+
"execution_count": null,
252252
"metadata": {
253253
"colab": {
254254
"base_uri": "https://localhost:8080/"
@@ -279,7 +279,7 @@
279279
],
280280
"source": [
281281
"# Get random number generator keys (see JAX documentation for why it is neccessary)\n",
282-
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys()\n",
282+
"rng_key, rng_key_predict = gpax.utils.get_keys()\n",
283283
"\n",
284284
"# Initialize model\n",
285285
"gp_model = gpax.ExactGP(1, kernel='Matern')\n",
@@ -288,7 +288,7 @@
288288
"gp_model.fit(rng_key, X, y, num_chains=1)\n",
289289
"\n",
290290
"# Get GP prediction\n",
291-
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
291+
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
292292
]
293293
},
294294
{
@@ -302,7 +302,7 @@
302302
},
303303
{
304304
"cell_type": "code",
305-
"execution_count": 10,
305+
"execution_count": null,
306306
"metadata": {
307307
"id": "lnxdYcLL67Xm",
308308
"outputId": "584a3a74-32e5-4f13-d3e5-0d2bcdf9cfe7",
@@ -348,7 +348,7 @@
348348
},
349349
{
350350
"cell_type": "code",
351-
"execution_count": 11,
351+
"execution_count": null,
352352
"metadata": {
353353
"id": "OjxPG_gY3U2c"
354354
},
@@ -371,7 +371,7 @@
371371
},
372372
{
373373
"cell_type": "code",
374-
"execution_count": 12,
374+
"execution_count": null,
375375
"metadata": {
376376
"id": "zdrtXqGPKzUe"
377377
},
@@ -401,7 +401,7 @@
401401
},
402402
{
403403
"cell_type": "code",
404-
"execution_count": 13,
404+
"execution_count": null,
405405
"metadata": {
406406
"id": "lqXxUSGeqGhm"
407407
},
@@ -436,7 +436,7 @@
436436
},
437437
{
438438
"cell_type": "code",
439-
"execution_count": 14,
439+
"execution_count": null,
440440
"metadata": {
441441
"colab": {
442442
"base_uri": "https://localhost:8080/",
@@ -524,7 +524,7 @@
524524
" gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
525525
"\n",
526526
" # Get GP prediction\n",
527-
" posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)\n",
527+
" posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)\n",
528528
"\n",
529529
" # Plot results\n",
530530
" _, ax = plt.subplots(dpi=100)\n",
@@ -572,7 +572,7 @@
572572
},
573573
{
574574
"cell_type": "code",
575-
"execution_count": 15,
575+
"execution_count": null,
576576
"metadata": {
577577
"id": "qRocZMUIVsp4"
578578
},
@@ -584,7 +584,7 @@
584584
},
585585
{
586586
"cell_type": "code",
587-
"execution_count": 16,
587+
"execution_count": null,
588588
"metadata": {
589589
"colab": {
590590
"base_uri": "https://localhost:8080/"
@@ -745,7 +745,7 @@
745745
}
746746
],
747747
"source": [
748-
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
748+
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
749749
"\n",
750750
"for i in range(6):\n",
751751
" print(\"\\nExploration step {}\".format(i+1))\n",
@@ -754,7 +754,7 @@
754754
" gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
755755
" # Compute acquisition function (here it is simply the uncertinty in prediciton)\n",
756756
" # and get the coordinate of the next point to measure\n",
757-
" obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n",
757+
" obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n",
758758
" next_point_idx = obj.argmax()\n",
759759
" # Append the 'suggested' point\n",
760760
" X = np.append(X, X_test[next_point_idx])\n",
@@ -773,7 +773,7 @@
773773
},
774774
{
775775
"cell_type": "code",
776-
"execution_count": 17,
776+
"execution_count": null,
777777
"metadata": {
778778
"colab": {
779779
"base_uri": "https://localhost:8080/"
@@ -806,17 +806,17 @@
806806
}
807807
],
808808
"source": [
809-
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
809+
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
810810
"# Update GP posterior\n",
811811
"gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise1, mean_fn_prior=piecewise1_priors)\n",
812812
"gp_model.fit(rng_key, X, y)\n",
813813
"# Get GP prediction\n",
814-
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
814+
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
815815
]
816816
},
817817
{
818818
"cell_type": "code",
819-
"execution_count": 18,
819+
"execution_count": null,
820820
"metadata": {
821821
"colab": {
822822
"base_uri": "https://localhost:8080/",
@@ -871,7 +871,7 @@
871871
},
872872
{
873873
"cell_type": "code",
874-
"execution_count": 19,
874+
"execution_count": null,
875875
"metadata": {
876876
"colab": {
877877
"base_uri": "https://localhost:8080/"
@@ -1106,15 +1106,15 @@
11061106
"source": [
11071107
"X, y = Xo, yo # start from the same set of observations\n",
11081108
"\n",
1109-
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
1109+
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
11101110
"\n",
11111111
"for i in range(9):\n",
11121112
" print(\"\\nExploration step {}\".format(i+1))\n",
11131113
" # Obtain/update GP posterior\n",
11141114
" gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
11151115
" gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
11161116
" # Compute acquisition function and get coordinate of the next point\n",
1117-
" obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n",
1117+
" obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n",
11181118
" next_point_idx = obj.argmax()\n",
11191119
" # Append the 'suggested' point\n",
11201120
" X = np.append(X, X_test[next_point_idx])\n",
@@ -1133,7 +1133,7 @@
11331133
},
11341134
{
11351135
"cell_type": "code",
1136-
"execution_count": 20,
1136+
"execution_count": null,
11371137
"metadata": {
11381138
"colab": {
11391139
"base_uri": "https://localhost:8080/"
@@ -1166,17 +1166,17 @@
11661166
}
11671167
],
11681168
"source": [
1169-
"rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n",
1169+
"rng_key, rng_key_predict = gpax.utils.get_keys(1)\n",
11701170
"# Update GP posterior\n",
11711171
"gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
11721172
"gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
11731173
"# Get GP prediction\n",
1174-
"posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)"
1174+
"posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)"
11751175
]
11761176
},
11771177
{
11781178
"cell_type": "code",
1179-
"execution_count": 21,
1179+
"execution_count": null,
11801180
"metadata": {
11811181
"colab": {
11821182
"base_uri": "https://localhost:8080/",

0 commit comments

Comments
 (0)