Add option to collect same number of episodes in each collector env #1046
…each env in collector
@@ -476,6 +526,9 @@ def collect(

:return: A dataclass object
"""
assert (
Why should we allow a param if it can't actually be changed?
…t recording worker wise returns!
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ep_last_idxs = np.array(ep_last_idxs)
ep_add_at_idxs = np.array(ep_add_at_idxs)
What are the semantics of this variable? The docstring calls it the current index. If it's not the same as ep_last_idxs, what does it mean?
The index at which the transition is added to the buffer.
Since ep_start_idx indicates the first transition of the current episode, ep_last_idx should be the index of the last transition in the episode. When the current transition does not contain done, this is not the last index of the episode (which continues) but the index at which the current transition is added.
Marking for discussion in pair programming
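To make the distinction in this thread concrete, here is a minimal sketch. It is not tianshou's buffer implementation: `ToyBuffer` and its `add` method are hypothetical, and a flat list stands in for the buffer. The point it illustrates is that the index at which a transition is written coincides with the episode's last index only when `done` is true.

```python
# Hypothetical toy buffer; names are illustrative only, not tianshou's API.
class ToyBuffer:
    def __init__(self) -> None:
        self.transitions: list = []
        self.ep_start_idx = 0  # first transition of the current episode

    def add(self, transition: dict) -> tuple:
        """Append a transition and return (ep_start_idx, add_at_idx, done)."""
        add_at_idx = len(self.transitions)  # where this transition is written
        self.transitions.append(transition)
        start = self.ep_start_idx
        done = bool(transition["done"])
        if done:
            # Only now is add_at_idx also the last index of the episode;
            # the next episode starts right after it.
            self.ep_start_idx = add_at_idx + 1
        return start, add_at_idx, done


buf = ToyBuffer()
results = [buf.add({"rew": 1.0, "done": d}) for d in (False, False, True, False)]
```

In this sketch a name such as `ep_add_at_idx` describes the returned index for every step, whereas `ep_last_idx` is only accurate for the step that ends the episode.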
collect_time=collect_call_duration,
collect_speed=step_count / collect_call_duration,
returns=np.array(episode_returns),
returns_stat=SequenceSummaryStats.from_sequence(episode_returns)
Let's move the declaration of this and lens_stat above the instantiation of the object; that would be a bit easier to read.
I'm also wondering whether this could be done in CollectStatsBase itself, in __post_init__. Let's have a look together later.
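One way the `__post_init__` idea could look is sketched below. This is an assumption-laden illustration, not the actual `CollectStatsBase` or `SequenceSummaryStats`: both classes here (`SummaryStats`, `EpisodeStats`) are hypothetical stand-ins with simplified fields, and the empty-sequence guard mirrors the `if len(...) > 0 else None` pattern from the diff.

```python
from __future__ import annotations

import statistics
from dataclasses import dataclass, field


@dataclass
class SummaryStats:
    """Hypothetical stand-in for a summary-statistics dataclass."""

    mean: float
    std: float

    @classmethod
    def from_sequence(cls, seq: list) -> SummaryStats:
        return cls(mean=statistics.fmean(seq), std=statistics.pstdev(seq))


@dataclass
class EpisodeStats:
    """Hypothetical stats container that derives its summaries itself."""

    returns: list
    lens: list
    # Derived fields: excluded from __init__ and filled in __post_init__.
    returns_stat: SummaryStats | None = field(init=False, default=None)
    lens_stat: SummaryStats | None = field(init=False, default=None)

    def __post_init__(self) -> None:
        # Summaries are only defined when at least one episode finished.
        if len(self.returns) > 0:
            self.returns_stat = SummaryStats.from_sequence(self.returns)
        if len(self.lens) > 0:
            self.lens_stat = SummaryStats.from_sequence(self.lens)


stats = EpisodeStats(returns=[1.0, 3.0], lens=[])
```

With this layout the caller passes only the raw sequences, so the conditional expressions no longer need to appear at the call site.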
tianshou/data/collector.py
Outdated
def reset_env(
    self,
    gym_reset_kwargs: dict[str, Any] | None = None,
    set_obs_next_to_obs: bool = False,
This is non-trivial functionality. Why is it needed, and when would one want to use it? Please add a docstring.
if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode):
    break
if n_episode:
Marking this block for discussion:
- Why is gym_reset_kwargs used in only one place?
- Should this be factored out into a separate method?
- Avoid np.where; use __getitem__ with a boolean array instead.
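The last point can be illustrated with a short sketch. The array names and values below are hypothetical and not taken from the PR; the sketch only shows that indexing with a boolean array selects the same elements as the `np.where` round trip.

```python
import numpy as np

# Hypothetical per-env bookkeeping, for illustration only.
ready_env_ids = np.array([0, 1, 2, 3])
episodes_done = np.array([2, 1, 2, 0])
target_per_env = 2

# Boolean mask: True for envs that still need more episodes.
still_collecting = episodes_done < target_per_env

# Direct boolean indexing ...
remaining = ready_env_ids[still_collecting]

# ... selects the same elements as converting the mask to integer indices first.
remaining_via_where = ready_env_ids[np.where(still_collecting)[0]]
```

The mask form keeps the condition and the selection in one readable expression and avoids the intermediate integer index array.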
lens_stat=SequenceSummaryStats.from_sequence(episode_lens)
if len(episode_lens) > 0
else None,
def sample_at_least_one_episode_per_worker_postprocessing_on_done_env(
This doesn't seem to sample anything; it only filters.
Marking for discussion.
self.data = self.data[mask]
return ready_env_ids

def sample_equal_episodes_per_worker_postprocessing_on_done_env(
Same as above.
done=done,
info=info,
)
if self.preprocess_fn:
Marking for discussion: when is this option used? Is it dangerous to update data twice?
tianshou/data/collector.py
Outdated
except TypeError:  # envpool's action space is not for per-env
    act_sample = [self._action_space.sample() for _ in ready_env_ids]
act_sample = self.policy.map_action_inverse(act_sample)  # type: ignore
self.data.update(act=act_sample)
From the method name I wouldn't expect it to update data. In general, data is modified in place all over the class; we should avoid that where possible.
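A minimal sketch of the alternative hinted at here: the sampling function returns its result and the caller decides how to store it. Everything below is hypothetical (`RolloutData`, `sample_random_actions`, the integer action space) and is not the collector's actual code or tianshou's `Batch` API.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RolloutData:
    """Hypothetical immutable container standing in for the collector's data."""

    obs: tuple
    act: tuple = ()


def sample_random_actions(n_envs: int, n_actions: int, rng: random.Random) -> tuple:
    """Return one random discrete action per env; touches no shared state."""
    return tuple(rng.randrange(n_actions) for _ in range(n_envs))


data = RolloutData(obs=(0.1, 0.2, 0.3))
rng = random.Random(0)

# The caller makes the update explicit instead of the sampler mutating `data`.
actions = sample_random_actions(len(data.obs), n_actions=4, rng=rng)
updated = replace(data, act=actions)
```

Here the original `data` object is left untouched, so the method name matches its effect and the update is visible at the call site.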
Overall, a much cleaner and more readable structure than before, but there are still some issues to address. Let's talk later today
Closing this, @bordeauxred will make a new one after #1063 is merged
poe format
poe lint and poe type-check
poe test (or a subset of them with poe test-reduced), and they pass
poe doc-build