diff --git a/src/lcm/grids.py b/src/lcm/grids.py index 9063eb5..9ea9843 100644 --- a/src/lcm/grids.py +++ b/src/lcm/grids.py @@ -198,9 +198,6 @@ def _validate_discrete_grid(category_class: type) -> None: if error_messages: msg = format_messages(error_messages) raise GridInitializationError(msg) - if error_messages: - msg = format_messages(error_messages) - raise GridInitializationError(msg) def _get_field_names_and_values(dc: type) -> dict[str, Any]: diff --git a/tests/input_processing/test_discrete_state_conversion.py b/tests/input_processing/test_discrete_state_conversion.py index d164e51..d0bb282 100644 --- a/tests/input_processing/test_discrete_state_conversion.py +++ b/tests/input_processing/test_discrete_state_conversion.py @@ -11,6 +11,7 @@ _get_code_to_index_func, _get_discrete_vars_with_non_index_codes, _get_index_to_code_func, + _get_next_index_func, convert_arbitrary_codes_to_array_indices, ) @@ -147,3 +148,13 @@ def test_convert_discrete_codes_to_indices(model): assert got.states["__c_index__"].codes == (0, 1) assert got.functions["c"](0) == 1 assert got.functions["c"](1) == 0 + + +def test_get_next_index_func(): + got_func = _get_next_index_func( + next_func=lambda wealth, working: jnp.clip(wealth + working, 1, 3), + codes_array=jnp.array([1, 2, 3]), + name="wealth", + ) + got = got_func(wealth=jnp.array([1, 2]), working=jnp.array([0, 1])) + assert_array_equal(got, jnp.array([0, 2]))