[BUG] `SafeProbabilisticModule` constructor missing `log_prob_keys` argument
#2731
Describe the bug
When log-prob aggregation is disabled for a probabilistic actor, you are supposed to pass a sequence of `log_prob_keys` as a parameter instead of a single `log_prob_key`. However, the parameter is not passed up the class hierarchy correctly in the constructor. The kwargs are forwarded to `SafeProbabilisticModule` properly, but its constructor does not expect a `log_prob_keys` argument, which results in a `TypeError`.
To Reproduce
Disable log-prob aggregation by using `set_composite_lp_aggregate(...).set()`, pass the `return_log_prob` argument to the `ProbabilisticActor` constructor, and provide a sequence of log-prob keys via the `log_prob_keys` argument.
Expected behavior
The `TypeError` does not occur.
System info
torchrl + tensordict from the current main branches.
Reason and Possible fixes
To fix it, it should be enough to add the expected argument to the `__init__` method of `SafeProbabilisticModule` and also pass it on to the superclass constructor of `ProbabilisticTensorDictModule`, which already defines the argument properly.
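To illustrate the failure mode and the proposed change without depending on the actual libraries, here is a minimal sketch in plain Python. The classes `ParentModule`, `BrokenSafeModule`, and `FixedSafeModule` are hypothetical stand-ins, not the real `ProbabilisticTensorDictModule` or `SafeProbabilisticModule`, and their signatures are simplified assumptions. Only the way the keyword argument is (or is not) forwarded is meant to mirror the report.

```python
# Hypothetical stand-ins: these are NOT the real tensordict/torchrl classes,
# and the signatures are simplified for illustration.


class ParentModule:
    """Plays the role of the superclass, which already accepts log_prob_keys."""

    def __init__(self, in_keys, out_keys=None, *, log_prob_key=None,
                 log_prob_keys=None, return_log_prob=False):
        self.in_keys = in_keys
        self.out_keys = out_keys
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys
        self.return_log_prob = return_log_prob


class BrokenSafeModule(ParentModule):
    """Mirrors the reported problem: __init__ has no log_prob_keys parameter."""

    def __init__(self, in_keys, out_keys=None, *, log_prob_key=None,
                 return_log_prob=False):
        super().__init__(in_keys, out_keys, log_prob_key=log_prob_key,
                         return_log_prob=return_log_prob)


class FixedSafeModule(ParentModule):
    """Proposed shape of the fix: accept log_prob_keys and forward it."""

    def __init__(self, in_keys, out_keys=None, *, log_prob_key=None,
                 log_prob_keys=None, return_log_prob=False):
        super().__init__(in_keys, out_keys, log_prob_key=log_prob_key,
                         log_prob_keys=log_prob_keys,
                         return_log_prob=return_log_prob)


def try_build(module_cls, **kwargs):
    """Return the constructed module, or the TypeError raised while building it."""
    try:
        return module_cls(["loc"], ["action"], **kwargs)
    except TypeError as err:
        return err


# Keyword arguments as an actor wrapper would forward them (key names are made up).
kwargs = {
    "return_log_prob": True,
    "log_prob_keys": ["head_0_log_prob", "head_1_log_prob"],
}

broken = try_build(BrokenSafeModule, **kwargs)
fixed = try_build(FixedSafeModule, **kwargs)

print(type(broken).__name__)  # the broken variant fails at construction time
print(fixed.log_prob_keys)    # the fixed variant stores the forwarded keys
```

In the real code base the change would presumably be just as small: one extra keyword parameter in `SafeProbabilisticModule.__init__` and one extra keyword in its `super().__init__` call.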