Improve the documentation of compute_episodic_return in base policy. (t…
bordeauxred authored Apr 30, 2024
1 parent a65920f commit 61426ac
Showing 2 changed files with 17 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/spelling_wordlist.txt
@@ -266,3 +266,8 @@ postfix
 backend
 rliable
 hl
+v_s
+v_s_
+obs
+obs_next

20 changes: 12 additions & 8 deletions tianshou/policy/base.py
@@ -301,9 +301,9 @@ def forward(
 :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
-* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
+* ``act`` a numpy.ndarray or a torch.Tensor, the action over \
 given batch data.
-* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
+* ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
 internal state of the policy, ``None`` as default.
 Other keys are user-defined. It depends on the algorithm. For example,
@@ -556,6 +556,9 @@ def compute_episodic_return(
 advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
 for estimating returns.
+Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the
+discounted return-to-go / Monte-Carlo return.
 :param batch: a data batch which contains several episodes of data in
 sequential order. Mind that the end of each finished episode of batch
 should be marked by done flag, unfinished (or collecting) episodes will be
@@ -565,10 +568,11 @@
 to buffer[indices].
 :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
 If None, it will be set to an array of 0.
-:param v_s: the value function of all current states :math:`V(s)`.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param v_s: the value function of all current states :math:`V(s)`. If None,
+it is set based upon v_s_ rolled by 1.
+:param gamma: the discount factor, should be in [0, 1].
 :param gae_lambda: the parameter for Generalized Advantage Estimation,
-should be in [0, 1]. Default to 0.95.
+should be in [0, 1].
 :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
 """
@@ -612,10 +616,10 @@ def compute_nstep_return(
 :param indices: tell batch's location in buffer
 :param function target_q_fn: a function which compute target Q value
 of "obs_next" given data buffer and wanted indices.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param gamma: the discount factor, should be in [0, 1].
 :param n_step: the number of estimation step, should be an int greater
-than 0. Default to 1.
-:param rew_norm: normalize the reward to Normal(0, 1), Default to False.
+than 0.
+:param rew_norm: normalize the reward to Normal(0, 1).
+TODO: passing True is not supported and will cause an error!
 :return: a Batch. The result will be stored in batch.returns as a
 torch.Tensor with the same shape as target_q_fn's return tensor.
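
For context (not part of this commit): a standalone numpy sketch of the standard n-step target that compute_nstep_return's docstring refers to. This is a conceptual illustration, not tianshou's vectorized implementation; the helper name nstep_target and all values below are made up for the example, and q_next[t] stands in for the target Q value of obs_next at step t.

import numpy as np

def nstep_target(rews, q_next, done, gamma=0.99, n_step=3):
    """Plain-numpy sketch of the standard n-step return target:
    G_t = r_t + gamma * r_{t+1} + ... + gamma^{n-1} * r_{t+n-1} + gamma^n * Q(s_{t+n}),
    with the bootstrap term dropped if the episode ends inside the window.
    """
    horizon = len(rews)
    target = np.zeros(horizon)
    for t in range(horizon):
        ret, discount, bootstrap, last = 0.0, 1.0, True, t
        for k in range(min(n_step, horizon - t)):
            last = t + k
            ret += discount * rews[last]
            discount *= gamma
            if done[last]:
                bootstrap = False  # episode ended: nothing to bootstrap from
                break
        if bootstrap:
            ret += discount * q_next[last]  # q_next[last] ~ Q(s_{last + 1})
        target[t] = ret
    return target

rews = np.array([1.0, 0.0, 0.0, 2.0])
done = np.array([False, False, False, True])
q_next = np.array([0.5, 0.5, 0.5, 0.0])
print(nstep_target(rews, q_next, done, gamma=0.9, n_step=2))  # [1.405, 0.405, 1.8, 2.0]
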
