Skip to content

Commit

Permalink
Update function evaluator (#81)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hans-Martin von Gaudecker <[email protected]>
  • Loading branch information
3 people authored Jun 27, 2024
1 parent 133ed92 commit 59d81aa
Show file tree
Hide file tree
Showing 20 changed files with 1,980 additions and 994 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,17 @@ jobs:
- name: Run mypy
shell: bash -l {0}
run: pixi run mypy
run-explanation-notebooks:
name: Run explanation notebooks on Python 3.12
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected]
with:
pixi-version: v0.23.0
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test-cpu
- name: Run explanation notebooks
shell: bash -l {0}
run: pixi run -e test-cpu explanation-notebooks
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ repos:
- id: mdformat
additional_dependencies:
- mdformat-gfm
- mdformat-gfm-alerts
- mdformat-black
args:
- --wrap
Expand Down
14 changes: 9 additions & 5 deletions explanations/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Explanations of `lcm` concepts
# Explanations of internal `lcm` concepts

## Choose a module
> [!NOTE]
> 1. The following explanations are designed for `lcm` developers and not users.
> 1. Figures are only rendered correctly on nbviewer, not on GitHub. Please use the
> links below to view the correctly rendered notebooks.
| Module name | Description |
| --------------------------------------- | ----------------------------------------------------------------- |
| [`dispatchers.py`](./dispatchers.ipynb) | Explanations of `spacemap`, `productmap`, and `vmap_1d` functions |
| Module name | Description |
| ---------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- |
| [`dispatchers.py`](https://nbviewer.org/github/OpenSourceEconomics/lcm/blob/main/explanations/dispatchers.ipynb) | Explanations of `spacemap`, `productmap`, and `vmap_1d` functions |
| [`function_representation.py`](https://nbviewer.org/github/OpenSourceEconomics/lcm/blob/main/explanations/function_representation.ipynb) | Explanations of what the function representation does and how it works |
29 changes: 10 additions & 19 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import pytest\n",
"from jax import vmap\n",
"from lcm.dispatchers import productmap, spacemap, vmap_1d"
]
Expand All @@ -27,7 +28,7 @@
"source": [
"# `vmap_1d`\n",
"\n",
"Let's vectorizing the function `f` over axis `a`."
"Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function."
]
},
{
Expand Down Expand Up @@ -98,23 +99,13 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=2, len(args)=0",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mf_vmapped\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"File \u001b[0;32m~/miniforge3/envs/lcm/lib/python3.11/site-packages/jax/_src/api.py:1259\u001b[0m, in \u001b[0;36mvmap.<locals>.vmap_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1255\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(fun, docstr\u001b[38;5;241m=\u001b[39mdocstr)\n\u001b[1;32m 1256\u001b[0m \u001b[38;5;129m@api_boundary\u001b[39m\n\u001b[1;32m 1257\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvmap_f\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1258\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(in_axes, \u001b[38;5;28mtuple\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(in_axes) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(args):\n\u001b[0;32m-> 1259\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvmap in_axes must be an int, None, or a tuple of entries corresponding \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto the positional arguments passed to the function, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1261\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(in_axes)\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(args)\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1262\u001b[0m args_flat, in_tree \u001b[38;5;241m=\u001b[39m tree_flatten((args, kwargs), is_leaf\u001b[38;5;241m=\u001b[39mbatching\u001b[38;5;241m.\u001b[39mis_vmappable)\n\u001b[1;32m 1263\u001b[0m f \u001b[38;5;241m=\u001b[39m lu\u001b[38;5;241m.\u001b[39mwrap_init(fun)\n",
"\u001b[0;31mValueError\u001b[0m: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=2, len(args)=0"
]
}
],
"outputs": [],
"source": [
"f_vmapped(a=a, b=1)"
"with pytest.raises(\n",
" ValueError,\n",
" match=\"vmap in_axes must be an int, None, or a tuple of entries corresponding to\",\n",
"):\n",
" f_vmapped(a=a, b=1)"
]
},
{
Expand Down Expand Up @@ -255,7 +246,7 @@
"The `spacemap` function combines `productmap` and `vmap_1d` in a way that is often\n",
"needed in `lcm`.\n",
"\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"\n",
"Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n",
"\n",
Expand Down Expand Up @@ -818,7 +809,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.0"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 59d81aa

Please sign in to comment.