Skip to content

Commit 8669f14

Browse files
committed
Move priors out of utils
1 parent 64bbec2 commit 8669f14

File tree

8 files changed

+236
-207
lines changed

8 files changed

+236
-207
lines changed

gpax/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .__version__ import version as __version__
2+
from . import priors
23
from . import utils
34
from . import kernels
45
from . import acquisition
@@ -7,6 +8,6 @@
78
vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
89
MeasuredNoiseGP, viSparseGP, BNN)
910

10-
__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
11+
__all__ = ["priors", "utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
1112
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
1213
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]

gpax/priors/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .priors import *

gpax/utils/priors.py gpax/priors/priors.py

+2-137
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,16 @@
44
55
Utility functions for setting priors
66
7-
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
7+
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
88
"""
99

1010
import inspect
11-
import re
1211

13-
from typing import Union, Dict, Type, List, Callable, Optional
12+
from typing import Union, Dict, Type, Callable
1413

1514
import numpyro
16-
import jax
1715
import jax.numpy as jnp
1816

19-
from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt
20-
2117

2218
def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
2319
"""
@@ -183,137 +179,6 @@ def uniform_dist(low: float = None,
183179
return numpyro.distributions.Uniform(low, high)
184180

185181

186-
def set_fn(func: Callable) -> Callable:
187-
"""
188-
Transforms the given deterministic function to use a params dictionary
189-
for its parameters, excluding the first one (assumed to be the dependent variable).
190-
191-
Args:
192-
- func (Callable): The deterministic function to be transformed.
193-
194-
Returns:
195-
- Callable: The transformed function where parameters are accessed
196-
from a `params` dictionary.
197-
"""
198-
# Extract parameter names excluding the first one (assumed to be the dependent variable)
199-
params_names = list(inspect.signature(func).parameters.keys())[1:]
200-
201-
# Create the transformed function definition
202-
transformed_code = f"def {func.__name__}(x, params):\n"
203-
204-
# Retrieve the source code of the function and indent it to be a valid function body
205-
source = inspect.getsource(func).split("\n", 1)[1]
206-
source = " " + source.replace("\n", "\n ")
207-
208-
# Replace each parameter name with its dictionary lookup using regex
209-
for name in params_names:
210-
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)
211-
212-
# Combine to get the full source
213-
transformed_code += source
214-
215-
# Define the transformed function in the local namespace
216-
local_namespace = {}
217-
exec(transformed_code, globals(), local_namespace)
218-
219-
# Return the transformed function
220-
return local_namespace[func.__name__]
221-
222-
223-
def set_kernel_fn(func: Callable,
224-
independent_vars: List[str] = ["X", "Z"],
225-
jit_decorator: bool = True,
226-
docstring: Optional[str] = None) -> Callable:
227-
"""
228-
Transforms the given kernel function to use a params dictionary for its hyperparameters.
229-
The resultant function will always add jitter before returning the computed kernel.
230-
231-
Args:
232-
func (Callable): The kernel function to be transformed.
233-
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
234-
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
235-
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
236-
237-
Returns:
238-
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
239-
"""
240-
241-
# Extract parameter names excluding the independent variables
242-
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
243-
for var in independent_vars:
244-
params_names.remove(var)
245-
246-
transformed_code = ""
247-
if jit_decorator:
248-
transformed_code += "@jit" + "\n"
249-
250-
additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
251-
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"
252-
253-
if docstring:
254-
transformed_code += ' """' + docstring + '"""\n'
255-
256-
source = inspect.getsource(func).split("\n", 1)[1]
257-
lines = source.split("\n")
258-
259-
for idx, line in enumerate(lines):
260-
# Convert all parameter names to their dictionary lookup throughout the function body
261-
for name in params_names:
262-
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])
263-
264-
# Combine lines back and then split again by return
265-
modified_source = '\n'.join(lines)
266-
pre_return, return_statement = modified_source.split('return', 1)
267-
268-
# Append custom jitter code
269-
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
270-
custom_code += """
271-
if X.shape == Z.shape:
272-
k += (noise + jitter) * jnp.eye(X.shape[0])
273-
return k
274-
"""
275-
276-
transformed_code += custom_code
277-
278-
local_namespace = {"jit": jax.jit}
279-
exec(transformed_code, globals(), local_namespace)
280-
281-
return local_namespace[func.__name__]
282-
283-
284-
def _set_noise_kernel_fn(func: Callable) -> Callable:
285-
"""
286-
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
287-
288-
Args:
289-
func (Callable): Original function.
290-
291-
Returns:
292-
Callable: Modified function.
293-
"""
294-
295-
# Get the source code of the function
296-
source = inspect.getsource(func)
297-
298-
# Split the source into decorators, definition, and body
299-
decorators_and_def, body = source.split("\n", 1)
300-
301-
# Replace all occurrences of params["k with params["k_noise in the body
302-
modified_body = re.sub(r'params\["k', 'params["k_noise', body)
303-
304-
# Combine decorators, definition, and modified body
305-
modified_source = f"{decorators_and_def}\n{modified_body}"
306-
307-
# Define local namespace including the jit decorator
308-
local_namespace = {"jit": jax.jit}
309-
310-
# Execute the modified source to redefine the function in the provided namespace
311-
exec(modified_source, globals(), local_namespace)
312-
313-
# Return the modified function
314-
return local_namespace[func.__name__]
315-
316-
317182
def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
318183
"""
319184
Generates a function that, when invoked, samples from normal or log-normal distributions

gpax/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .utils import *
2-
from .priors import *
3-
from .priors import _set_noise_kernel_fn
2+
from .fn import *
3+
from .fn import _set_noise_kernel_fn

gpax/utils/fn.py

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""
2+
fn.py
3+
=====
4+
5+
Utilities for setting up custom mean and kernel functions
6+
7+
Created by Maxim Ziatdinov (email: [email protected])
8+
"""
9+
10+
import inspect
11+
import re
12+
13+
from typing import List, Callable, Optional
14+
15+
import jax
16+
17+
from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt
18+
19+
20+
def set_fn(func: Callable) -> Callable:
21+
"""
22+
Transforms the given deterministic function to use a params dictionary
23+
for its parameters, excluding the first one (assumed to be the dependent variable).
24+
25+
Args:
26+
- func (Callable): The deterministic function to be transformed.
27+
28+
Returns:
29+
- Callable: The transformed function where parameters are accessed
30+
from a `params` dictionary.
31+
"""
32+
# Extract parameter names excluding the first one (assumed to be the dependent variable)
33+
params_names = list(inspect.signature(func).parameters.keys())[1:]
34+
35+
# Create the transformed function definition
36+
transformed_code = f"def {func.__name__}(x, params):\n"
37+
38+
# Retrieve the source code of the function and indent it to be a valid function body
39+
source = inspect.getsource(func).split("\n", 1)[1]
40+
source = " " + source.replace("\n", "\n ")
41+
42+
# Replace each parameter name with its dictionary lookup using regex
43+
for name in params_names:
44+
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)
45+
46+
# Combine to get the full source
47+
transformed_code += source
48+
49+
# Define the transformed function in the local namespace
50+
local_namespace = {}
51+
exec(transformed_code, globals(), local_namespace)
52+
53+
# Return the transformed function
54+
return local_namespace[func.__name__]
55+
56+
57+
def set_kernel_fn(func: Callable,
58+
independent_vars: List[str] = ["X", "Z"],
59+
jit_decorator: bool = True,
60+
docstring: Optional[str] = None) -> Callable:
61+
"""
62+
Transforms the given kernel function to use a params dictionary for its hyperparameters.
63+
The resultant function will always add jitter before returning the computed kernel.
64+
65+
Args:
66+
func (Callable): The kernel function to be transformed.
67+
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
68+
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
69+
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
70+
71+
Returns:
72+
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
73+
"""
74+
75+
# Extract parameter names excluding the independent variables
76+
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
77+
for var in independent_vars:
78+
params_names.remove(var)
79+
80+
transformed_code = ""
81+
if jit_decorator:
82+
transformed_code += "@jit" + "\n"
83+
84+
additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
85+
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"
86+
87+
if docstring:
88+
transformed_code += ' """' + docstring + '"""\n'
89+
90+
source = inspect.getsource(func).split("\n", 1)[1]
91+
lines = source.split("\n")
92+
93+
for idx, line in enumerate(lines):
94+
# Convert all parameter names to their dictionary lookup throughout the function body
95+
for name in params_names:
96+
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])
97+
98+
# Combine lines back and then split again by return
99+
modified_source = '\n'.join(lines)
100+
pre_return, return_statement = modified_source.split('return', 1)
101+
102+
# Append custom jitter code
103+
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
104+
custom_code += """
105+
if X.shape == Z.shape:
106+
k += (noise + jitter) * jnp.eye(X.shape[0])
107+
return k
108+
"""
109+
110+
transformed_code += custom_code
111+
112+
local_namespace = {"jit": jax.jit}
113+
exec(transformed_code, globals(), local_namespace)
114+
115+
return local_namespace[func.__name__]
116+
117+
118+
def _set_noise_kernel_fn(func: Callable) -> Callable:
119+
"""
120+
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
121+
122+
Args:
123+
func (Callable): Original function.
124+
125+
Returns:
126+
Callable: Modified function.
127+
"""
128+
129+
# Get the source code of the function
130+
source = inspect.getsource(func)
131+
132+
# Split the source into decorators, definition, and body
133+
decorators_and_def, body = source.split("\n", 1)
134+
135+
# Replace all occurrences of params["k with params["k_noise in the body
136+
modified_body = re.sub(r'params\["k', 'params["k_noise', body)
137+
138+
# Combine decorators, definition, and modified body
139+
modified_source = f"{decorators_and_def}\n{modified_body}"
140+
141+
# Define local namespace including the jit decorator
142+
local_namespace = {"jit": jax.jit}
143+
144+
# Execute the modified source to redefine the function in the provided namespace
145+
exec(modified_source, globals(), local_namespace)
146+
147+
# Return the modified function
148+
return local_namespace[func.__name__]

0 commit comments

Comments
 (0)