Skip to content

Commit

Permalink
minor rename
Browse files (browse the repository at this point in the history)
  • Loading branch information
bordeauxred committed Mar 7, 2024
1 parent e9a3278 commit c6a707e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def collect(
episode_start_indices: list[int] = []

# DIMENSION NAMING CONVENTION
# R - number ready env ids
# R - number of ready env ids. Note that this might change when some envs
# become idle.
# O - dimension(s) of observations
# A - dimension(s) of actions
# H - dimension(s) of hidden state
Expand Down Expand Up @@ -341,7 +342,7 @@ def collect(
)

# TODO: cleanup the whole policy in batch thing
# update state / act / policy into cur_rollout_batch
# TODO: policy_R can also be None; check for that case
policy_R = act_batch_R.get("policy", Batch())
if not isinstance(policy_R, Batch):
raise RuntimeError(
Expand All @@ -352,9 +353,9 @@ def collect(
policy_R.hidden_state = hidden_state_RH # save state into buffer
self.last_hidden_state_of_policy = hidden_state_RH

normalized_action = self.policy.map_action(act_RA)
normalized_action_R = self.policy.map_action(act_RA)
obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step(
normalized_action,
normalized_action_R,
ready_env_ids_R,
)
done_R = np.logical_or(terminated_R, truncated_R)
Expand All @@ -380,7 +381,7 @@ def collect(
time.sleep(render)

# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add(
rollout_batch,
buffer_ids=ready_env_ids_R,
)
Expand All @@ -392,9 +393,9 @@ def collect(
env_ind_local = np.where(done_R)[0]
env_ind_global = ready_env_ids_R[env_ind_local]
episode_count += len(env_ind_local)
episode_lens.extend(ep_len[env_ind_local])
episode_returns.extend(ep_rew[env_ind_local])
episode_start_indices.extend(ep_idx[env_ind_local])
episode_lens.extend(ep_len_R[env_ind_local])
episode_returns.extend(ep_rew_R[env_ind_local])
episode_start_indices.extend(ep_idx_R[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.

Expand Down

0 comments on commit c6a707e

Please sign in to comment.