From 61426acf0754305347e4c740f15003007a7ff8f6 Mon Sep 17 00:00:00 2001 From: bordeauxred <2robert.mueller@gmail.com> Date: Tue, 30 Apr 2024 14:40:16 +0200 Subject: [PATCH] Improve the documentation of compute_episodic_return in base policy. (#1130) --- docs/spelling_wordlist.txt | 5 +++++ tianshou/policy/base.py | 20 ++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 383429eb2..83de82356 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -266,3 +266,8 @@ postfix backend rliable hl +v_s +v_s_ +obs +obs_next + diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 77602a02b..1462ff4cc 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -301,9 +301,9 @@ def forward( :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: - * ``act`` an numpy.ndarray or a torch.Tensor, the action over \ + * ``act`` a numpy.ndarray or a torch.Tensor, the action over \ given batch data. - * ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \ + * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \ internal state of the policy, ``None`` as default. Other keys are user-defined. It depends on the algorithm. For example, @@ -556,6 +556,9 @@ def compute_episodic_return( advantage + value, which is exactly equivalent to using :math:`TD(\lambda)` for estimating returns. + Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the + discounted return-to-go/ Monte-Carlo return. + :param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be @@ -565,10 +568,11 @@ def compute_episodic_return( to buffer[indices]. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. If None, it will be set to an array of 0. - :param v_s: the value function of all current states :math:`V(s)`. - :param gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param v_s: the value function of all current states :math:`V(s)`. If None, + it is set based upon v_s_ rolled by 1. + :param gamma: the discount factor, should be in [0, 1]. :param gae_lambda: the parameter for Generalized Advantage Estimation, - should be in [0, 1]. Default to 0.95. + should be in [0, 1]. :return: two numpy arrays (returns, advantage) with each shape (bsz, ). """ @@ -612,10 +616,10 @@ def compute_nstep_return( :param indices: tell batch's location in buffer :param function target_q_fn: a function which compute target Q value of "obs_next" given data buffer and wanted indices. - :param gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param gamma: the discount factor, should be in [0, 1]. :param n_step: the number of estimation step, should be an int greater - than 0. Default to 1. - :param rew_norm: normalize the reward to Normal(0, 1), Default to False. + than 0. + :param rew_norm: normalize the reward to Normal(0, 1). TODO: passing True is not supported and will cause an error! :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor.