|
4 | 4 |
|
5 | 5 | Utility functions for setting priors
|
6 | 6 |
|
7 |
| -Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) |
| 7 | +Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com) |
8 | 8 | """
|
9 | 9 |
|
10 | 10 | import inspect
|
11 |
| -import re |
12 | 11 |
|
13 |
| -from typing import Union, Dict, Type, List, Callable, Optional |
| 12 | +from typing import Union, Dict, Type, Callable |
14 | 13 |
|
15 | 14 | import numpyro
|
16 |
| -import jax |
17 | 15 | import jax.numpy as jnp
|
18 | 16 |
|
19 |
| -from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt |
20 |
| - |
21 | 17 |
|
22 | 18 | def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
|
23 | 19 | """
|
@@ -183,137 +179,6 @@ def uniform_dist(low: float = None,
|
183 | 179 | return numpyro.distributions.Uniform(low, high)
|
184 | 180 |
|
185 | 181 |
|
186 |
| -def set_fn(func: Callable) -> Callable: |
187 |
| - """ |
188 |
| - Transforms the given deterministic function to use a params dictionary |
189 |
| - for its parameters, excluding the first one (assumed to be the dependent variable). |
190 |
| -
|
191 |
| - Args: |
192 |
| - - func (Callable): The deterministic function to be transformed. |
193 |
| -
|
194 |
| - Returns: |
195 |
| - - Callable: The transformed function where parameters are accessed |
196 |
| - from a `params` dictionary. |
197 |
| - """ |
198 |
| - # Extract parameter names excluding the first one (assumed to be the dependent variable) |
199 |
| - params_names = list(inspect.signature(func).parameters.keys())[1:] |
200 |
| - |
201 |
| - # Create the transformed function definition |
202 |
| - transformed_code = f"def {func.__name__}(x, params):\n" |
203 |
| - |
204 |
| - # Retrieve the source code of the function and indent it to be a valid function body |
205 |
| - source = inspect.getsource(func).split("\n", 1)[1] |
206 |
| - source = " " + source.replace("\n", "\n ") |
207 |
| - |
208 |
| - # Replace each parameter name with its dictionary lookup using regex |
209 |
| - for name in params_names: |
210 |
| - source = re.sub(rf'\b{name}\b', f'params["{name}"]', source) |
211 |
| - |
212 |
| - # Combine to get the full source |
213 |
| - transformed_code += source |
214 |
| - |
215 |
| - # Define the transformed function in the local namespace |
216 |
| - local_namespace = {} |
217 |
| - exec(transformed_code, globals(), local_namespace) |
218 |
| - |
219 |
| - # Return the transformed function |
220 |
| - return local_namespace[func.__name__] |
221 |
| - |
222 |
| - |
223 |
| -def set_kernel_fn(func: Callable, |
224 |
| - independent_vars: List[str] = ["X", "Z"], |
225 |
| - jit_decorator: bool = True, |
226 |
| - docstring: Optional[str] = None) -> Callable: |
227 |
| - """ |
228 |
| - Transforms the given kernel function to use a params dictionary for its hyperparameters. |
229 |
| - The resultant function will always add jitter before returning the computed kernel. |
230 |
| -
|
231 |
| - Args: |
232 |
| - func (Callable): The kernel function to be transformed. |
233 |
| - independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"]. |
234 |
| - jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True. |
235 |
| - docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None. |
236 |
| -
|
237 |
| - Returns: |
238 |
| - Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary. |
239 |
| - """ |
240 |
| - |
241 |
| - # Extract parameter names excluding the independent variables |
242 |
| - params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty] |
243 |
| - for var in independent_vars: |
244 |
| - params_names.remove(var) |
245 |
| - |
246 |
| - transformed_code = "" |
247 |
| - if jit_decorator: |
248 |
| - transformed_code += "@jit" + "\n" |
249 |
| - |
250 |
| - additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs" |
251 |
| - transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n" |
252 |
| - |
253 |
| - if docstring: |
254 |
| - transformed_code += ' """' + docstring + '"""\n' |
255 |
| - |
256 |
| - source = inspect.getsource(func).split("\n", 1)[1] |
257 |
| - lines = source.split("\n") |
258 |
| - |
259 |
| - for idx, line in enumerate(lines): |
260 |
| - # Convert all parameter names to their dictionary lookup throughout the function body |
261 |
| - for name in params_names: |
262 |
| - lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx]) |
263 |
| - |
264 |
| - # Combine lines back and then split again by return |
265 |
| - modified_source = '\n'.join(lines) |
266 |
| - pre_return, return_statement = modified_source.split('return', 1) |
267 |
| - |
268 |
| - # Append custom jitter code |
269 |
| - custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n" |
270 |
| - custom_code += """ |
271 |
| - if X.shape == Z.shape: |
272 |
| - k += (noise + jitter) * jnp.eye(X.shape[0]) |
273 |
| - return k |
274 |
| - """ |
275 |
| - |
276 |
| - transformed_code += custom_code |
277 |
| - |
278 |
| - local_namespace = {"jit": jax.jit} |
279 |
| - exec(transformed_code, globals(), local_namespace) |
280 |
| - |
281 |
| - return local_namespace[func.__name__] |
282 |
| - |
283 |
| - |
284 |
| -def _set_noise_kernel_fn(func: Callable) -> Callable: |
285 |
| - """ |
286 |
| - Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses. |
287 |
| -
|
288 |
| - Args: |
289 |
| - func (Callable): Original function. |
290 |
| -
|
291 |
| - Returns: |
292 |
| - Callable: Modified function. |
293 |
| - """ |
294 |
| - |
295 |
| - # Get the source code of the function |
296 |
| - source = inspect.getsource(func) |
297 |
| - |
298 |
| - # Split the source into decorators, definition, and body |
299 |
| - decorators_and_def, body = source.split("\n", 1) |
300 |
| - |
301 |
| - # Replace all occurrences of params["k with params["k_noise in the body |
302 |
| - modified_body = re.sub(r'params\["k', 'params["k_noise', body) |
303 |
| - |
304 |
| - # Combine decorators, definition, and modified body |
305 |
| - modified_source = f"{decorators_and_def}\n{modified_body}" |
306 |
| - |
307 |
| - # Define local namespace including the jit decorator |
308 |
| - local_namespace = {"jit": jax.jit} |
309 |
| - |
310 |
| - # Execute the modified source to redefine the function in the provided namespace |
311 |
| - exec(modified_source, globals(), local_namespace) |
312 |
| - |
313 |
| - # Return the modified function |
314 |
| - return local_namespace[func.__name__] |
315 |
| - |
316 |
| - |
317 | 182 | def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
|
318 | 183 | """
|
319 | 184 | Generates a function that, when invoked, samples from normal or log-normal distributions
|
|
0 commit comments