Skip to content

Commit 67a084b

Browse files
kurtamohler authored and vmoens committed
[Test] Improve coverage of ChessEnv.all_actions
ghstack-source-id: f8d40f6
Pull Request resolved: #2849
1 parent efe9389 commit 67a084b

File tree

1 file changed

+41
-20
lines changed

1 file changed

+41
-20
lines changed

test/test_env.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4234,43 +4234,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
42344234
td_check = env.reset(td.select("fen_hash"))
42354235
assert (td_check == td).all()
42364236

4237-
@pytest.mark.parametrize("include_fen", [False, True])
4238-
@pytest.mark.parametrize("include_pgn", [False, True])
4237+
@pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
42394238
@pytest.mark.parametrize("stateful", [False, True])
4240-
@pytest.mark.parametrize("mask_actions", [False, True])
4241-
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4242-
if not stateful and not include_fen and not include_pgn:
4243-
# pytest.skip("fen or pgn must be included if not stateful")
4244-
return
4245-
4239+
@pytest.mark.parametrize("include_hash", [False, True])
4240+
@pytest.mark.parametrize("include_san", [False, True])
4241+
@pytest.mark.parametrize("append_transform", [False, True])
4242+
@pytest.mark.parametrize("mask_actions", [True])
4243+
def test_all_actions(
4244+
self,
4245+
include_fen,
4246+
include_pgn,
4247+
stateful,
4248+
include_hash,
4249+
include_san,
4250+
append_transform,
4251+
mask_actions,
4252+
):
42464253
env = ChessEnv(
42474254
include_fen=include_fen,
42484255
include_pgn=include_pgn,
4256+
include_san=include_san,
4257+
include_hash=include_hash,
4258+
include_hash_inv=include_hash,
42494259
stateful=stateful,
42504260
mask_actions=mask_actions,
42514261
)
4252-
td = env.reset()
42534262

4254-
if not mask_actions:
4255-
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
4256-
env.all_actions()
4257-
return
4263+
def transform_reward(td):
4264+
if "reward" not in td:
4265+
return td
4266+
reward = td["reward"]
4267+
if reward == 0.5:
4268+
td["reward"] = 0
4269+
elif reward == 1 and td["turn"]:
4270+
td["reward"] = -td["reward"]
4271+
return td
4272+
4273+
if append_transform:
4274+
env = env.append_transform(transform_reward)
4275+
4276+
check_env_specs(env)
4277+
4278+
td = env.reset()
42584279

42594280
# Choose random actions from the output of `all_actions`
4260-
for _ in range(100):
4261-
if stateful:
4262-
all_actions = env.all_actions()
4263-
else:
4281+
for step_idx in range(100):
4282+
if step_idx % 5 == 0:
42644283
# Reset the initial state first, just to make sure
42654284
# `all_actions` knows how to get the board state from the input.
42664285
env.reset()
4267-
all_actions = env.all_actions(td.clone())
4286+
all_actions = env.all_actions(td.clone())
42684287

42694288
# Choose some random actions and make sure they match exactly one of
42704289
# the actions from `all_actions`. This part is not tested when
42714290
# `mask_actions == False`, because `rand_action` can pick illegal
42724291
# actions in that case.
4273-
if mask_actions:
4292+
if mask_actions and step_idx % 4 == 0:
42744293
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
42754294
# it fail to work properly for stateless mode. It doesn't know
42764295
# how to correctly reset the board state to what is given in the
@@ -4287,7 +4306,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
42874306

42884307
action_idx = torch.randint(0, all_actions.shape[0], ()).item()
42894308
chosen_action = all_actions[action_idx]
4290-
td = env.step(td.update(chosen_action))["next"]
4309+
td_new = env.step(td.update(chosen_action).clone())
4310+
assert (td == td_new.exclude("next")).all()
4311+
td = td_new["next"]
42914312

42924313
if td["done"]:
42934314
td = env.reset()

0 commit comments

Comments
 (0)