diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 94a78581..f0420fe4 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.0 with: - pixi-version: v0.23.0 + pixi-version: v0.29.0 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: test-cpu @@ -48,7 +48,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.0 with: - pixi-version: v0.23.0 + pixi-version: v0.29.0 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: mypy @@ -62,7 +62,7 @@ jobs: - uses: actions/checkout@v4 - uses: prefix-dev/setup-pixi@v0.8.0 with: - pixi-version: v0.23.0 + pixi-version: v0.29.0 cache: true cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} environments: test-cpu diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f09e552b..c6c5559a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/lyz-code/yamlfix - rev: 1.16.0 + rev: 1.17.0 hooks: - id: yamlfix - repo: https://github.com/pre-commit/pre-commit-hooks @@ -46,7 +46,7 @@ repos: hooks: - id: yamllint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.2 + rev: v0.6.5 hooks: # Run the linter. - id: ruff @@ -74,5 +74,26 @@ repos: - --wrap - '88' files: (README\.md) + - repo: https://github.com/kynan/nbstripout + rev: 0.7.1 + hooks: + - id: nbstripout + args: + - --drop-empty-cells + - --keep-output + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + files: src|tests + additional_dependencies: + - jax>=0.4.20 + - numpy + - packaging + - pandas-stubs + args: + - --config=pyproject.toml ci: autoupdate_schedule: monthly + skip: + - mypy # installing jax is not possible on pre-commit.ci due to size limits. diff --git a/codecov.yml b/codecov.yml index 740448d8..c01f5dab 100644 --- a/codecov.yml +++ b/codecov.yml @@ -14,5 +14,4 @@ coverage: default: target: 90% ignore: - - setup.py - tests/**/* diff --git a/examples/long_running.py b/examples/long_running.py index 52596fb1..9c54a5f0 100644 --- a/examples/long_running.py +++ b/examples/long_running.py @@ -1,19 +1,8 @@ """Example specification for a consumption-savings model with health and exercise.""" import jax.numpy as jnp -from lcm import DiscreteGrid, LinspaceGrid, Model - -# ====================================================================================== -# Numerical parameters and constants -# ====================================================================================== -N_GRID_POINTS = { - "wealth": 100, - "health": 100, - "consumption": 100, - "exercise": 200, -} -RETIREMENT_AGE = 65 +from lcm import DiscreteGrid, LinspaceGrid, Model # ====================================================================================== # Model functions @@ -63,6 +52,8 @@ def consumption_constraint(consumption, wealth, labor_income): # ====================================================================================== # Model specification and parameters # ====================================================================================== +RETIREMENT_AGE = 65 + MODEL_CONFIG = Model( n_periods=RETIREMENT_AGE - 18, @@ -80,24 +71,24 @@ def consumption_constraint(consumption, wealth, labor_income): "consumption": LinspaceGrid( start=1, stop=100, - n_points=N_GRID_POINTS["consumption"], + n_points=100, ), "exercise": LinspaceGrid( start=0, stop=1, - n_points=N_GRID_POINTS["exercise"], + n_points=200, ), }, states={ "wealth": LinspaceGrid( start=1, stop=100, - n_points=N_GRID_POINTS["wealth"], + n_points=100, ), "health": LinspaceGrid( start=0, stop=1, - n_points=N_GRID_POINTS["health"], + n_points=100, ), }, ) diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb index f744223e..81aae290 100644 --- a/explanations/dispatchers.ipynb +++ b/explanations/dispatchers.ipynb @@ -12,13 +12,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "import pytest\n", "from jax import vmap\n", + "\n", "from lcm.dispatchers import productmap, spacemap, vmap_1d" ] }, @@ -33,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -52,7 +53,7 @@ "Array([1. , 1.25, 1.5 , 1.75, 2. ], dtype=float32)" ] }, - "execution_count": 3, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -77,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -86,7 +87,7 @@ "Array([1. , 1.25, 1.5 , 1.75, 2. ], dtype=float32)" ] }, - "execution_count": 4, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -117,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -135,7 +136,7 @@ "Array([1. , 1.25, 1.5 , 1.75, 2. ], dtype=float32)" ] }, - "execution_count": 7, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -156,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -166,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -178,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -187,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -202,7 +203,7 @@ " [ 2, 3, 4, 5]]], dtype=int32)" ] }, - "execution_count": 11, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -214,7 +215,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -223,7 +224,7 @@ "(2, 3, 4)" ] }, - "execution_count": 12, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -269,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -307,11 +308,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from lcm.process_model import process_model\n", + "from lcm.input_processing import process_model\n", "from lcm.state_space import create_state_choice_space\n", "\n", "processed_model = process_model(model)\n", @@ -335,7 +336,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -344,7 +345,7 @@ "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, - "execution_count": 15, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -355,7 +356,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -365,7 +366,7 @@ " 'retirement': Array([0, 1, 1], dtype=int32)}" ] }, - "execution_count": 16, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -376,7 +377,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -431,7 +432,7 @@ "2 1 1" ] }, - "execution_count": 17, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -497,7 +498,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -506,7 +507,7 @@ "{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}" ] }, - "execution_count": 18, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -539,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -553,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -562,7 +563,7 @@ "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, - "execution_count": 20, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -573,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -583,7 +584,7 @@ " 'retirement': Array([0, 1, 1], dtype=int32)}" ] }, - "execution_count": 21, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -594,7 +595,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -605,7 +606,7 @@ " [ 1. , 2. , 3. , 4. ]], dtype=float32)" ] }, - "execution_count": 22, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -621,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -630,7 +631,7 @@ "(3, 4)" ] }, - "execution_count": 23, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -648,7 +649,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -659,7 +660,7 @@ " [ 1. , 2. , 3. , 4. ]], dtype=float32)" ] }, - "execution_count": 24, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -725,7 +726,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -734,7 +735,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -749,7 +750,7 @@ " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)" ] }, - "execution_count": 26, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -765,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -774,7 +775,7 @@ "(2, 3, 4)" ] }, - "execution_count": 27, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } diff --git a/explanations/function_representation.ipynb b/explanations/function_representation.ipynb index 11f6cf33..0e89ddb5 100644 --- a/explanations/function_representation.ipynb +++ b/explanations/function_representation.ipynb @@ -72,11 +72,12 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", + "\n", "from lcm import DiscreteGrid, LinspaceGrid, Model\n", "\n", "\n", @@ -152,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -164,7 +165,7 @@ } ], "source": [ - "from lcm.process_model import process_model\n", + "from lcm.input_processing import process_model\n", "\n", "processed_model = process_model(model)" ] @@ -182,7 +183,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -191,7 +192,7 @@ "" ] }, - "execution_count": 3, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -215,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -247,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -256,7 +257,7 @@ "dict_keys(['retirement', 'wealth', 'consumption'])" ] }, - "execution_count": 5, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -267,7 +268,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -299,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -308,7 +309,7 @@ "(10,)" ] }, - "execution_count": 7, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -320,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -329,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -359,9 +360,9 @@ { "data": { "text/html": [ - "