Skip to content

Commit b03480a

Browse files
authored
Merge pull request #94 from ziatdinovmax/util
Move 'priors' out of 'utils' and turn them into a separate module
2 parents 64bbec2 + cb28ab8 commit b03480a

12 files changed

+276
-235
lines changed

docs/source/index.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ GPax is a small Python package for physics-based Gaussian processes (GPs) built
1919
:caption: Package Content
2020

2121
models
22-
hypo
2322
acquisition
2423
kernels
24+
priors
25+
hypo
2526
utils
2627

2728
.. toctree::

docs/source/models.rst

+10
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,13 @@ Multi-Task Learning
9696
:undoc-members:
9797
:member-order: bysource
9898
:show-inheritance:
99+
100+
Structured Probabilistic Models
101+
-------------------------------
102+
.. autoclass:: gpax.models.spm.sPM
103+
:members:
104+
:inherited-members:
105+
:undoc-members:
106+
:member-order: bysource
107+
:show-inheritance:
108+

docs/source/priors.rst

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
Priors
2+
======
3+
4+
.. autofunction:: gpax.utils.normal_dist
5+
6+
.. autofunction:: gpax.utils.lognormal_dist
7+
8+
.. autofunction:: gpax.utils.halfnormal_dist
9+
10+
.. autofunction:: gpax.utils.gamma_dist
11+
12+
.. autofunction:: gpax.utils.uniform_dist
13+
14+
.. autofunction:: gpax.utils.place_normal_prior
15+
16+
.. autofunction:: gpax.utils.place_lognormal_prior
17+
18+
.. autofunction:: gpax.utils.place_halfnormal_prior
19+
20+
.. autofunction:: gpax.utils.place_uniform_prior
21+
22+
.. autofunction:: gpax.utils.place_gamma_prior
23+

docs/source/utils.rst

+4-27
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,17 @@
11
Utilities
22
=========
33

4-
Priors
5-
------
4+
Automatic function setters
5+
--------------------------
66

7-
.. autofunction:: gpax.utils.normal_dist
7+
.. autofunction:: gpax.utils.set_fn
88

9-
.. autofunction:: gpax.utils.lognormal_dist
10-
11-
.. autofunction:: gpax.utils.halfnormal_dist
12-
13-
.. autofunction:: gpax.utils.gamma_dist
14-
15-
.. autofunction:: gpax.utils.uniform_dist
16-
17-
.. autofunction:: gpax.utils.place_normal_prior
18-
19-
.. autofunction:: gpax.utils.place_lognormal_prior
20-
21-
.. autofunction:: gpax.utils.place_halfnormal_prior
22-
23-
.. autofunction:: gpax.utils.place_uniform_prior
24-
25-
.. autofunction:: gpax.utils.place_gamma_prior
9+
.. autofunction:: gpax.utils.set_kernel_fn
2610

2711

2812
Other utilities
2913
---------------
3014

31-
.. autoclass:: gpax.models.spm.sPM
32-
:members:
33-
:inherited-members:
34-
:undoc-members:
35-
:member-order: bysource
36-
:show-inheritance:
37-
3815
.. autofunction:: gpax.utils.dviz
3916

4017
.. autofunction:: gpax.utils.get_keys

gpax/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .__version__ import version as __version__
2+
from . import priors
23
from . import utils
34
from . import kernels
45
from . import acquisition
@@ -7,6 +8,6 @@
78
vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
89
MeasuredNoiseGP, viSparseGP, BNN)
910

10-
__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
11+
__all__ = ["priors", "utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
1112
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
1213
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]

gpax/priors/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .priors import *

gpax/utils/priors.py gpax/priors/priors.py

+2-137
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,16 @@
44
55
Utility functions for setting priors
66
7-
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
7+
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
88
"""
99

1010
import inspect
11-
import re
1211

13-
from typing import Union, Dict, Type, List, Callable, Optional
12+
from typing import Union, Dict, Type, Callable
1413

1514
import numpyro
16-
import jax
1715
import jax.numpy as jnp
1816

19-
from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt
20-
2117

2218
def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
2319
"""
@@ -183,137 +179,6 @@ def uniform_dist(low: float = None,
183179
return numpyro.distributions.Uniform(low, high)
184180

185181

186-
def set_fn(func: Callable) -> Callable:
187-
"""
188-
Transforms the given deterministic function to use a params dictionary
189-
for its parameters, excluding the first one (assumed to be the dependent variable).
190-
191-
Args:
192-
- func (Callable): The deterministic function to be transformed.
193-
194-
Returns:
195-
- Callable: The transformed function where parameters are accessed
196-
from a `params` dictionary.
197-
"""
198-
# Extract parameter names excluding the first one (assumed to be the dependent variable)
199-
params_names = list(inspect.signature(func).parameters.keys())[1:]
200-
201-
# Create the transformed function definition
202-
transformed_code = f"def {func.__name__}(x, params):\n"
203-
204-
# Retrieve the source code of the function and indent it to be a valid function body
205-
source = inspect.getsource(func).split("\n", 1)[1]
206-
source = " " + source.replace("\n", "\n ")
207-
208-
# Replace each parameter name with its dictionary lookup using regex
209-
for name in params_names:
210-
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)
211-
212-
# Combine to get the full source
213-
transformed_code += source
214-
215-
# Define the transformed function in the local namespace
216-
local_namespace = {}
217-
exec(transformed_code, globals(), local_namespace)
218-
219-
# Return the transformed function
220-
return local_namespace[func.__name__]
221-
222-
223-
def set_kernel_fn(func: Callable,
224-
independent_vars: List[str] = ["X", "Z"],
225-
jit_decorator: bool = True,
226-
docstring: Optional[str] = None) -> Callable:
227-
"""
228-
Transforms the given kernel function to use a params dictionary for its hyperparameters.
229-
The resultant function will always add jitter before returning the computed kernel.
230-
231-
Args:
232-
func (Callable): The kernel function to be transformed.
233-
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
234-
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
235-
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
236-
237-
Returns:
238-
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
239-
"""
240-
241-
# Extract parameter names excluding the independent variables
242-
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
243-
for var in independent_vars:
244-
params_names.remove(var)
245-
246-
transformed_code = ""
247-
if jit_decorator:
248-
transformed_code += "@jit" + "\n"
249-
250-
additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
251-
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"
252-
253-
if docstring:
254-
transformed_code += ' """' + docstring + '"""\n'
255-
256-
source = inspect.getsource(func).split("\n", 1)[1]
257-
lines = source.split("\n")
258-
259-
for idx, line in enumerate(lines):
260-
# Convert all parameter names to their dictionary lookup throughout the function body
261-
for name in params_names:
262-
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])
263-
264-
# Combine lines back and then split again by return
265-
modified_source = '\n'.join(lines)
266-
pre_return, return_statement = modified_source.split('return', 1)
267-
268-
# Append custom jitter code
269-
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
270-
custom_code += """
271-
if X.shape == Z.shape:
272-
k += (noise + jitter) * jnp.eye(X.shape[0])
273-
return k
274-
"""
275-
276-
transformed_code += custom_code
277-
278-
local_namespace = {"jit": jax.jit}
279-
exec(transformed_code, globals(), local_namespace)
280-
281-
return local_namespace[func.__name__]
282-
283-
284-
def _set_noise_kernel_fn(func: Callable) -> Callable:
285-
"""
286-
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
287-
288-
Args:
289-
func (Callable): Original function.
290-
291-
Returns:
292-
Callable: Modified function.
293-
"""
294-
295-
# Get the source code of the function
296-
source = inspect.getsource(func)
297-
298-
# Split the source into decorators, definition, and body
299-
decorators_and_def, body = source.split("\n", 1)
300-
301-
# Replace all occurrences of params["k with params["k_noise in the body
302-
modified_body = re.sub(r'params\["k', 'params["k_noise', body)
303-
304-
# Combine decorators, definition, and modified body
305-
modified_source = f"{decorators_and_def}\n{modified_body}"
306-
307-
# Define local namespace including the jit decorator
308-
local_namespace = {"jit": jax.jit}
309-
310-
# Execute the modified source to redefine the function in the provided namespace
311-
exec(modified_source, globals(), local_namespace)
312-
313-
# Return the modified function
314-
return local_namespace[func.__name__]
315-
316-
317182
def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
318183
"""
319184
Generates a function that, when invoked, samples from normal or log-normal distributions

gpax/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .utils import *
2-
from .priors import *
3-
from .priors import _set_noise_kernel_fn
2+
from .fn import *
3+
from .fn import _set_noise_kernel_fn

0 commit comments

Comments
 (0)