Skip to content

Commit

Permalink
Merge branch 'main' into long-running-test
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Feb 13, 2024
2 parents 01a4461 + e72f3ba commit a761e8e
Show file tree
Hide file tree
Showing 28 changed files with 260 additions and 189 deletions.
9 changes: 1 addition & 8 deletions .envs/testenv.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ dependencies:
- setuptools_scm
- toml

# Package dependencies
- dags
- jax>=0.4.10
- jaxlib>=0.4.10
- numpy
- pandas

# Testing dependencies
- scipy
- pybaum
Expand All @@ -24,6 +17,6 @@ dependencies:
- pytest-cov
- pytest-xdist

# Install lcm locally
# Install lcm and its dependencies locally
- pip:
- -e ../
5 changes: 2 additions & 3 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,15 @@ jobs:
environment-file: ./.envs/testenv.yml
environment-name: lcm
cache-environment: true
create-args: >
extra-specs: |
create-args: >-
python=${{ matrix.python-version }}
- name: run pytest
shell: bash -l {0}
run: |
micromamba activate lcm
pytest --cov-report=xml --cov=./
- name: Upload coverage report.
if: runner.os == 'Linux' && matrix.python-version == '3.10'
if: runner.os == 'Linux' && matrix.python-version == '3.11'
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/publish-to-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.10
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
- name: Install pypa/build
run: >-
python -m
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: yamllint
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 24.1.1
hooks:
- id: black
language_version: python3.11
Expand All @@ -54,13 +54,13 @@ repos:
hooks:
- id: blacken-docs
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.6
rev: v0.2.0
hooks:
- id: ruff
# args:
# - --verbose
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
args:
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
documentation root, use os.path.abspath to make it absolute, like shown here.
"""

import pathlib
import sys

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ extend-ignore = [
"TRY003",
]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"docs/source/conf.py" = ["E501", "ERA001", "DTZ005"]
"tests/test_*.py" = ["PLR2004"]

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "google"


Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ keywords =
packages = find:
install_requires =
dags
jax>=0.3.0
jaxlib>=0.3.0
jax>=0.4.10
jaxlib>=0.4.10
numpy
pandas
python_requires = >=3.10
python_requires = >=3.11
include_package_data = True
package_dir =
=src
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Setup script for the package."""

from setuptools import setup

if __name__ == "__main__":
Expand Down
15 changes: 8 additions & 7 deletions src/lcm/create_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Create a parameters for a model specification."""

import inspect

import numpy as np
Expand Down Expand Up @@ -102,12 +103,12 @@ def _create_shock_params(model, variable_info, grids):
inspect.signature(model["functions"][f"next_{var}"]).parameters,
)

_check_variables_are_all_discrete_states(
_check_variables_are_all_discrete_or_period(
variables=dependencies,
variable_info=variable_info,
msg_suffix=(
f"The function next_{var} can only depend on discrete state variables "
f"or '_period'."
f"The function next_{var} can only depend on discrete variables or"
f"'_period'."
),
)

Expand All @@ -128,10 +129,10 @@ def _create_standard_params():
return {"beta": np.nan}


def _check_variables_are_all_discrete_states(variables, variable_info, msg_suffix):
discrete_state_vars = variable_info.query("is_state and is_discrete").index.tolist()
def _check_variables_are_all_discrete_or_period(variables, variable_info, msg_suffix):
discrete_vars = variable_info.query("is_discrete").index.tolist()
for var in variables:
if var not in discrete_state_vars and var != "_period":
if var not in discrete_vars and var != "_period":
raise ValueError(
f"Variable {var} is not a discrete state variable. {msg_suffix}",
f"Variable {var} is not a discrete variable. {msg_suffix}",
)
1 change: 1 addition & 0 deletions src/lcm/discrete_emax.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
https://github.com/google/jax/issues/6265
"""

from functools import partial

import jax
Expand Down
1 change: 1 addition & 0 deletions src/lcm/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
in JAX.
"""

import numpy as np


Expand Down
1 change: 1 addition & 0 deletions src/lcm/example_models/example_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define example model specifications."""

import jax.numpy as jnp

RETIREMENT_AGE = 65
Expand Down
71 changes: 63 additions & 8 deletions src/lcm/example_models/example_models_stochastic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Define example model specifications."""

import jax.numpy as jnp

import lcm
Expand All @@ -21,7 +22,7 @@ def next_health(health, partner): # noqa: ARG001


@lcm.mark.stochastic
def next_partner(_period):
def next_partner(_period, working, partner): # noqa: ARG001
pass


Expand All @@ -41,7 +42,7 @@ def consumption_constraint(consumption, wealth):
"working": {"options": [0, 1]},
"consumption": {
"grid_type": "linspace",
"start": 0,
"start": 1,
"stop": 100,
"n_points": N_CHOICE_GRID_POINTS,
},
Expand All @@ -51,7 +52,7 @@ def consumption_constraint(consumption, wealth):
"partner": {"options": [0, 1]},
"wealth": {
"grid_type": "linspace",
"start": 0,
"start": 1,
"stop": 100,
"n_points": N_STATE_GRID_POINTS,
},
Expand All @@ -61,13 +62,67 @@ def consumption_constraint(consumption, wealth):


PARAMS = {
"beta": 0.25,
"utility": {"delta": 0.25, "gamma": 0.25},
"next_wealth": {"interest_rate": 0.25, "wage": 0.25},
"beta": 0.95,
"utility": {"delta": 0.5, "gamma": 0.25},
"next_wealth": {"interest_rate": 0.05, "wage": 10.0},
"next_health": {},
"consumption_constraint": {},
"shocks": {
"health": jnp.array([[[0.5, 0.5], [0.5, 0.5]], [[0.5, 0.5], [0.5, 0.5]]]),
"partner": jnp.array([[1.0, 0], [0.0, 1]]),
# Health shock:
# ------------------------------------------------------------------------------
# 1st dimension: Current health state
# 2nd dimension: Current Partner state
# 3rd dimension: Probability distribution over next period's health state
"health": jnp.array(
[
# Current health state 0
[
# Current Partner state 0
[0.9, 0.1],
# Current Partner state 1
[0.5, 0.5],
],
# Current health state 1
[
# Current Partner state 0
[0.5, 0.5],
# Current Partner state 1
[0.1, 0.9],
],
],
),
# Partner shock:
# ------------------------------------------------------------------------------
# 1st dimension: The period
# 2nd dimension: Current working decision
# 3rd dimension: Current partner state
# 4th dimension: Probability distribution over next period's partner state
"partner": jnp.array(
[
# Transition from period 0 to period 1
[
# Current working decision 0
[
# Current partner state 0
[0, 1.0],
# Current partner state 1
[1.0, 0],
],
# Current working decision 1
[
# Current partner state 0
[0, 1.0],
# Current partner state 1
[0.0, 1.0],
],
],
# Transition from period 1 to period 2
[
# Description is the same as above
[[0, 1.0], [1.0, 0]],
[[0, 1.0], [0.0, 1.0]],
],
],
),
},
}
16 changes: 10 additions & 6 deletions src/lcm/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def allow_kwargs(func):

# Create new signature without positional-only arguments
new_parameters = [
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
if p.kind == inspect.Parameter.POSITIONAL_ONLY
else p
(
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
if p.kind == inspect.Parameter.POSITIONAL_ONLY
else p
)
for p in parameters.values()
]
new_signature = signature.replace(parameters=new_parameters)
Expand Down Expand Up @@ -117,9 +119,11 @@ def allow_args(func):

# Create new signature without keyword-only arguments
new_parameters = [
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
if p.kind == inspect.Parameter.KEYWORD_ONLY
else p
(
p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
if p.kind == inspect.Parameter.KEYWORD_ONLY
else p
)
for p in parameters.values()
]
new_signature = signature.replace(parameters=new_parameters)
Expand Down
1 change: 1 addition & 0 deletions src/lcm/get_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Get a user model and parameters."""

from typing import NamedTuple

from pybaum import tree_update
Expand Down
1 change: 1 addition & 0 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
it easy to call functions interchangeably.
"""

import jax.numpy as jnp


Expand Down
1 change: 1 addition & 0 deletions src/lcm/mark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Collection of LCM marking decorators."""

import functools
from typing import NamedTuple

Expand Down
7 changes: 6 additions & 1 deletion src/lcm/model_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ def u_and_f(*args, **kwargs):
_period=period,
params=kwargs["params"],
)
weights = next_weights(**states, _period=period, params=kwargs["params"])
weights = next_weights(
**states,
**choices,
_period=period,
params=kwargs["params"],
)

value_function = productmap(
scalar_value_function,
Expand Down
1 change: 1 addition & 0 deletions src/lcm/next_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
deteministic next states.
"""

from dags import concatenate_functions
from dags.signature import with_signature

Expand Down
8 changes: 4 additions & 4 deletions src/lcm/process_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,12 @@ def _get_stochastic_weight_function(raw_func, name, variable_info, grids):
"""
function_parameters = list(inspect.signature(raw_func).parameters)

# Assert that stochastic next function only depends on state variables
# Assert that stochastic next function only depends on discrete variables or period
for arg in function_parameters:
if arg != "_period" and not variable_info.loc[arg, "is_state"]:
if arg != "_period" and not variable_info.loc[arg, "is_discrete"]:
raise ValueError(
f"Stochastic variables can only depend on state variables and '_period'"
f" but {name} depends on {arg}.",
f"Stochastic variables can only depend on discrete variables and "
f"'_period', but {name} depends on {arg}.",
)

label_translators = {
Expand Down
Loading

0 comments on commit a761e8e

Please sign in to comment.