Skip to content

Commit

Permalink
Test [1, 0] discrete grid case
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 19, 2024
1 parent 8a28ceb commit 227b859
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,21 @@ def next_c(a, b):
"a": DiscreteGrid(category_class_factory([0, 1])),
},
states={
"c": DiscreteGrid(category_class_factory([1, 10])),
# unsorted indices cannot be treated as indices
"c": DiscreteGrid(category_class_factory([1, 0])),
},
)


def test_get_index_to_label_func():
codes_array = jnp.array([1, 10])
codes_array = jnp.array([1, 0])
got = _get_index_to_code_func(codes_array, name="foo")
assert got(__foo_index__=0) == 1
assert got(1) == 10
assert got(1) == 0


def test_get_code_to_index_func():
codes_array = jnp.array([1, 10])
codes_array = jnp.array([1, 0])
got = _get_code_to_index_func(codes_array, name="foo")
assert_array_equal(got(foo=codes_array), jnp.arange(2))

Expand All @@ -78,4 +79,4 @@ def test_convert_discrete_codes_to_indices(model):
assert got.states["__c_index__"].categories == ["__cat0_index__", "__cat1_index__"]
assert got.states["__c_index__"].codes == [0, 1]
assert got.functions["c"](0) == 1
assert got.functions["c"](1) == 10
assert got.functions["c"](1) == 0

0 comments on commit 227b859

Please sign in to comment.