|
12 | 12 | },
|
13 | 13 | {
|
14 | 14 | "cell_type": "code",
|
15 |
| - "execution_count": 1, |
| 15 | + "execution_count": null, |
16 | 16 | "metadata": {
|
17 | 17 | "id": "136rQl-Z67Xf"
|
18 | 18 | },
|
|
70 | 70 | },
|
71 | 71 | {
|
72 | 72 | "cell_type": "code",
|
73 |
| - "execution_count": 2, |
| 73 | + "execution_count": null, |
74 | 74 | "metadata": {
|
75 | 75 | "id": "VQ1rLUzqha2i"
|
76 | 76 | },
|
|
90 | 90 | },
|
91 | 91 | {
|
92 | 92 | "cell_type": "code",
|
93 |
| - "execution_count": 3, |
| 93 | + "execution_count": null, |
94 | 94 | "metadata": {
|
95 | 95 | "id": "XCoyWlKt67Xk"
|
96 | 96 | },
|
|
110 | 110 | },
|
111 | 111 | {
|
112 | 112 | "cell_type": "code",
|
113 |
| - "execution_count": 4, |
| 113 | + "execution_count": null, |
114 | 114 | "metadata": {
|
115 | 115 | "id": "KtGDc11Ehh7r"
|
116 | 116 | },
|
|
133 | 133 | },
|
134 | 134 | {
|
135 | 135 | "cell_type": "code",
|
136 |
| - "execution_count": 5, |
| 136 | + "execution_count": null, |
137 | 137 | "metadata": {
|
138 | 138 | "id": "V5isV5Ho67Xl"
|
139 | 139 | },
|
|
144 | 144 | },
|
145 | 145 | {
|
146 | 146 | "cell_type": "code",
|
147 |
| - "execution_count": 6, |
| 147 | + "execution_count": null, |
148 | 148 | "metadata": {
|
149 | 149 | "id": "gUyKDZjM67Xl"
|
150 | 150 | },
|
|
179 | 179 | },
|
180 | 180 | {
|
181 | 181 | "cell_type": "code",
|
182 |
| - "execution_count": 7, |
| 182 | + "execution_count": null, |
183 | 183 | "metadata": {
|
184 | 184 | "id": "LAvbGDom67Xl"
|
185 | 185 | },
|
|
205 | 205 | },
|
206 | 206 | {
|
207 | 207 | "cell_type": "code",
|
208 |
| - "execution_count": 8, |
| 208 | + "execution_count": null, |
209 | 209 | "metadata": {
|
210 | 210 | "colab": {
|
211 | 211 | "base_uri": "https://localhost:8080/",
|
|
248 | 248 | },
|
249 | 249 | {
|
250 | 250 | "cell_type": "code",
|
251 |
| - "execution_count": 9, |
| 251 | + "execution_count": null, |
252 | 252 | "metadata": {
|
253 | 253 | "colab": {
|
254 | 254 | "base_uri": "https://localhost:8080/"
|
|
279 | 279 | ],
|
280 | 280 | "source": [
|
281 | 281 | "# Get random number generator keys (see JAX documentation for why it is neccessary)\n",
|
282 |
| - "rng_key, rng_keposterior_meanict = gpax.utils.get_keys()\n", |
| 282 | + "rng_key, rng_key_predict = gpax.utils.get_keys()\n", |
283 | 283 | "\n",
|
284 | 284 | "# Initialize model\n",
|
285 | 285 | "gp_model = gpax.ExactGP(1, kernel='Matern')\n",
|
|
288 | 288 | "gp_model.fit(rng_key, X, y, num_chains=1)\n",
|
289 | 289 | "\n",
|
290 | 290 | "# Get GP prediction\n",
|
291 |
| - "posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)" |
| 291 | + "posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)" |
292 | 292 | ]
|
293 | 293 | },
|
294 | 294 | {
|
|
302 | 302 | },
|
303 | 303 | {
|
304 | 304 | "cell_type": "code",
|
305 |
| - "execution_count": 10, |
| 305 | + "execution_count": null, |
306 | 306 | "metadata": {
|
307 | 307 | "id": "lnxdYcLL67Xm",
|
308 | 308 | "outputId": "584a3a74-32e5-4f13-d3e5-0d2bcdf9cfe7",
|
|
348 | 348 | },
|
349 | 349 | {
|
350 | 350 | "cell_type": "code",
|
351 |
| - "execution_count": 11, |
| 351 | + "execution_count": null, |
352 | 352 | "metadata": {
|
353 | 353 | "id": "OjxPG_gY3U2c"
|
354 | 354 | },
|
|
371 | 371 | },
|
372 | 372 | {
|
373 | 373 | "cell_type": "code",
|
374 |
| - "execution_count": 12, |
| 374 | + "execution_count": null, |
375 | 375 | "metadata": {
|
376 | 376 | "id": "zdrtXqGPKzUe"
|
377 | 377 | },
|
|
401 | 401 | },
|
402 | 402 | {
|
403 | 403 | "cell_type": "code",
|
404 |
| - "execution_count": 13, |
| 404 | + "execution_count": null, |
405 | 405 | "metadata": {
|
406 | 406 | "id": "lqXxUSGeqGhm"
|
407 | 407 | },
|
|
436 | 436 | },
|
437 | 437 | {
|
438 | 438 | "cell_type": "code",
|
439 |
| - "execution_count": 14, |
| 439 | + "execution_count": null, |
440 | 440 | "metadata": {
|
441 | 441 | "colab": {
|
442 | 442 | "base_uri": "https://localhost:8080/",
|
|
524 | 524 | " gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
|
525 | 525 | "\n",
|
526 | 526 | " # Get GP prediction\n",
|
527 |
| - " posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)\n", |
| 527 | + " posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)\n", |
528 | 528 | "\n",
|
529 | 529 | " # Plot results\n",
|
530 | 530 | " _, ax = plt.subplots(dpi=100)\n",
|
|
572 | 572 | },
|
573 | 573 | {
|
574 | 574 | "cell_type": "code",
|
575 |
| - "execution_count": 15, |
| 575 | + "execution_count": null, |
576 | 576 | "metadata": {
|
577 | 577 | "id": "qRocZMUIVsp4"
|
578 | 578 | },
|
|
584 | 584 | },
|
585 | 585 | {
|
586 | 586 | "cell_type": "code",
|
587 |
| - "execution_count": 16, |
| 587 | + "execution_count": null, |
588 | 588 | "metadata": {
|
589 | 589 | "colab": {
|
590 | 590 | "base_uri": "https://localhost:8080/"
|
|
745 | 745 | }
|
746 | 746 | ],
|
747 | 747 | "source": [
|
748 |
| - "rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n", |
| 748 | + "rng_key, rng_key_predict = gpax.utils.get_keys(1)\n", |
749 | 749 | "\n",
|
750 | 750 | "for i in range(6):\n",
|
751 | 751 | " print(\"\\nExploration step {}\".format(i+1))\n",
|
|
754 | 754 | " gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
|
755 | 755 | " # Compute acquisition function (here it is simply the uncertinty in prediciton)\n",
|
756 | 756 | " # and get the coordinate of the next point to measure\n",
|
757 |
| - " obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n", |
| 757 | + " obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n", |
758 | 758 | " next_point_idx = obj.argmax()\n",
|
759 | 759 | " # Append the 'suggested' point\n",
|
760 | 760 | " X = np.append(X, X_test[next_point_idx])\n",
|
|
773 | 773 | },
|
774 | 774 | {
|
775 | 775 | "cell_type": "code",
|
776 |
| - "execution_count": 17, |
| 776 | + "execution_count": null, |
777 | 777 | "metadata": {
|
778 | 778 | "colab": {
|
779 | 779 | "base_uri": "https://localhost:8080/"
|
|
806 | 806 | }
|
807 | 807 | ],
|
808 | 808 | "source": [
|
809 |
| - "rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n", |
| 809 | + "rng_key, rng_key_predict = gpax.utils.get_keys(1)\n", |
810 | 810 | "# Update GP posterior\n",
|
811 | 811 | "gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise1, mean_fn_prior=piecewise1_priors)\n",
|
812 | 812 | "gp_model.fit(rng_key, X, y)\n",
|
813 | 813 | "# Get GP prediction\n",
|
814 |
| - "posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)" |
| 814 | + "posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)" |
815 | 815 | ]
|
816 | 816 | },
|
817 | 817 | {
|
818 | 818 | "cell_type": "code",
|
819 |
| - "execution_count": 18, |
| 819 | + "execution_count": null, |
820 | 820 | "metadata": {
|
821 | 821 | "colab": {
|
822 | 822 | "base_uri": "https://localhost:8080/",
|
|
871 | 871 | },
|
872 | 872 | {
|
873 | 873 | "cell_type": "code",
|
874 |
| - "execution_count": 19, |
| 874 | + "execution_count": null, |
875 | 875 | "metadata": {
|
876 | 876 | "colab": {
|
877 | 877 | "base_uri": "https://localhost:8080/"
|
|
1106 | 1106 | "source": [
|
1107 | 1107 | "X, y = Xo, yo # start from the same set of observations\n",
|
1108 | 1108 | "\n",
|
1109 |
| - "rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n", |
| 1109 | + "rng_key, rng_key_predict = gpax.utils.get_keys(1)\n", |
1110 | 1110 | "\n",
|
1111 | 1111 | "for i in range(9):\n",
|
1112 | 1112 | " print(\"\\nExploration step {}\".format(i+1))\n",
|
1113 | 1113 | " # Obtain/update GP posterior\n",
|
1114 | 1114 | " gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
|
1115 | 1115 | " gp_model.fit(rng_key, X, y, print_summary=1, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
|
1116 | 1116 | " # Compute acquisition function and get coordinate of the next point\n",
|
1117 |
| - " obj = gpax.acquisition.UE(rng_keposterior_meanict, gp_model, X_test)\n", |
| 1117 | + " obj = gpax.acquisition.UE(rng_key_predict, gp_model, X_test)\n", |
1118 | 1118 | " next_point_idx = obj.argmax()\n",
|
1119 | 1119 | " # Append the 'suggested' point\n",
|
1120 | 1120 | " X = np.append(X, X_test[next_point_idx])\n",
|
|
1133 | 1133 | },
|
1134 | 1134 | {
|
1135 | 1135 | "cell_type": "code",
|
1136 |
| - "execution_count": 20, |
| 1136 | + "execution_count": null, |
1137 | 1137 | "metadata": {
|
1138 | 1138 | "colab": {
|
1139 | 1139 | "base_uri": "https://localhost:8080/"
|
|
1166 | 1166 | }
|
1167 | 1167 | ],
|
1168 | 1168 | "source": [
|
1169 |
| - "rng_key, rng_keposterior_meanict = gpax.utils.get_keys(1)\n", |
| 1169 | + "rng_key, rng_key_predict = gpax.utils.get_keys(1)\n", |
1170 | 1170 | "# Update GP posterior\n",
|
1171 | 1171 | "gp_model = gpax.ExactGP(1, kernel='Matern', mean_fn=piecewise2, mean_fn_prior=piecewise2_priors)\n",
|
1172 | 1172 | "gp_model.fit(rng_key, X, y, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES)\n",
|
1173 | 1173 | "# Get GP prediction\n",
|
1174 |
| - "posterior_mean, f_samples = gp_model.predict(rng_keposterior_meanict, X_test, n=200)" |
| 1174 | + "posterior_mean, f_samples = gp_model.predict(rng_key_predict, X_test, n=200)" |
1175 | 1175 | ]
|
1176 | 1176 | },
|
1177 | 1177 | {
|
1178 | 1178 | "cell_type": "code",
|
1179 |
| - "execution_count": 21, |
| 1179 | + "execution_count": null, |
1180 | 1180 | "metadata": {
|
1181 | 1181 | "colab": {
|
1182 | 1182 | "base_uri": "https://localhost:8080/",
|
|
0 commit comments