@@ -4234,43 +4234,62 @@ def test_env_reset_with_hash(self, stateful, include_san):
4234
4234
td_check = env .reset (td .select ("fen_hash" ))
4235
4235
assert (td_check == td ).all ()
4236
4236
4237
- @pytest .mark .parametrize ("include_fen" , [False , True ])
4238
- @pytest .mark .parametrize ("include_pgn" , [False , True ])
4237
+ @pytest .mark .parametrize ("include_fen,include_pgn" , [[False , True ], [True , False ]])
4239
4238
@pytest .mark .parametrize ("stateful" , [False , True ])
4240
- @pytest .mark .parametrize ("mask_actions" , [False , True ])
4241
- def test_all_actions (self , include_fen , include_pgn , stateful , mask_actions ):
4242
- if not stateful and not include_fen and not include_pgn :
4243
- # pytest.skip("fen or pgn must be included if not stateful")
4244
- return
4245
-
4239
+ @pytest .mark .parametrize ("include_hash" , [False , True ])
4240
+ @pytest .mark .parametrize ("include_san" , [False , True ])
4241
+ @pytest .mark .parametrize ("append_transform" , [False , True ])
4242
+ @pytest .mark .parametrize ("mask_actions" , [True ])
4243
+ def test_all_actions (
4244
+ self ,
4245
+ include_fen ,
4246
+ include_pgn ,
4247
+ stateful ,
4248
+ include_hash ,
4249
+ include_san ,
4250
+ append_transform ,
4251
+ mask_actions ,
4252
+ ):
4246
4253
env = ChessEnv (
4247
4254
include_fen = include_fen ,
4248
4255
include_pgn = include_pgn ,
4256
+ include_san = include_san ,
4257
+ include_hash = include_hash ,
4258
+ include_hash_inv = include_hash ,
4249
4259
stateful = stateful ,
4250
4260
mask_actions = mask_actions ,
4251
4261
)
4252
- td = env .reset ()
4253
4262
4254
- if not mask_actions :
4255
- with pytest .raises (RuntimeError , match = "Cannot generate legal actions" ):
4256
- env .all_actions ()
4257
- return
4263
+ def transform_reward (td ):
4264
+ if "reward" not in td :
4265
+ return td
4266
+ reward = td ["reward" ]
4267
+ if reward == 0.5 :
4268
+ td ["reward" ] = 0
4269
+ elif reward == 1 and td ["turn" ]:
4270
+ td ["reward" ] = - td ["reward" ]
4271
+ return td
4272
+
4273
+ if append_transform :
4274
+ env = env .append_transform (transform_reward )
4275
+
4276
+ check_env_specs (env )
4277
+
4278
+ td = env .reset ()
4258
4279
4259
4280
# Choose random actions from the output of `all_actions`
4260
- for _ in range (100 ):
4261
- if stateful :
4262
- all_actions = env .all_actions ()
4263
- else :
4281
+ for step_idx in range (100 ):
4282
+ if step_idx % 5 == 0 :
4264
4283
# Reset theinitial state first, just to make sure
4265
4284
# `all_actions` knows how to get the board state from the input.
4266
4285
env .reset ()
4267
- all_actions = env .all_actions (td .clone ())
4286
+ all_actions = env .all_actions (td .clone ())
4268
4287
4269
4288
# Choose some random actions and make sure they match exactly one of
4270
4289
# the actions from `all_actions`. This part is not tested when
4271
4290
# `mask_actions == False`, because `rand_action` can pick illegal
4272
4291
# actions in that case.
4273
- if mask_actions :
4292
+ if mask_actions and step_idx % 4 == 0 :
4274
4293
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
4275
4294
# it fail to work properly for stateless mode. It doesn't know
4276
4295
# how to correctly reset the board state to what is given in the
@@ -4287,7 +4306,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
4287
4306
4288
4307
action_idx = torch .randint (0 , all_actions .shape [0 ], ()).item ()
4289
4308
chosen_action = all_actions [action_idx ]
4290
- td = env .step (td .update (chosen_action ))["next" ]
4309
+ td_new = env .step (td .update (chosen_action ).clone ())
4310
+ assert (td == td_new .exclude ("next" )).all ()
4311
+ td = td_new ["next" ]
4291
4312
4292
4313
if td ["done" ]:
4293
4314
td = env .reset ()
0 commit comments