Skip to content

Commit

Permalink
Some fixes and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 20, 2024
1 parent 2cf15b0 commit b1264ec
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
3 changes: 0 additions & 3 deletions src/lcm/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,6 @@ def _validate_discrete_grid(category_class: type) -> None:
if error_messages:
msg = format_messages(error_messages)
raise GridInitializationError(msg)
if error_messages:
msg = format_messages(error_messages)
raise GridInitializationError(msg)


def _get_field_names_and_values(dc: type) -> dict[str, Any]:
Expand Down
11 changes: 11 additions & 0 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_get_code_to_index_func,
_get_discrete_vars_with_non_index_codes,
_get_index_to_code_func,
_get_next_index_func,
convert_arbitrary_codes_to_array_indices,
)

Expand Down Expand Up @@ -147,3 +148,13 @@ def test_convert_discrete_codes_to_indices(model):
assert got.states["__c_index__"].codes == (0, 1)
assert got.functions["c"](0) == 1
assert got.functions["c"](1) == 0


def test_get_next_index_func():
got_func = _get_next_index_func(
next_func=lambda wealth, working: jnp.clip(wealth + working, 1, 3),
codes_array=jnp.array([1, 2, 3]),
name="wealth",
)
got = got_func(wealth=jnp.array([1, 2]), working=jnp.array([0, 1]))
assert_array_equal(got, jnp.array([0, 2]))

0 comments on commit b1264ec

Please sign in to comment.