Description
Describe the bug
When logprob aggregation is disabled for a probabilistic actor, you are supposed to pass a sequence of log_prob_keys as a parameter instead of a single log_prob_key. However, the parameter is not passed up the class hierarchy correctly in the constructor.
The kwargs are forwarded to SafeProbabilisticModule as expected, but its constructor does not accept a log_prob_keys argument, which results in a TypeError:
TypeError: SafeProbabilisticModule.__init__() got an unexpected keyword argument 'log_prob_keys'
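The failure mode can be illustrated without torchrl. The sketch below is a simplified model of the hierarchy, not the library code: `Base`, `Safe`, and `Actor` are hypothetical stand-ins for ProbabilisticTensorDictModule, SafeProbabilisticModule, and ProbabilisticActor, and the key names are made up.

```python
class Base:
    """Stand-in for ProbabilisticTensorDictModule: accepts both arguments."""

    def __init__(self, log_prob_key=None, log_prob_keys=None):
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class Safe(Base):
    """Stand-in for SafeProbabilisticModule: log_prob_keys is not declared."""

    def __init__(self, log_prob_key=None):
        super().__init__(log_prob_key=log_prob_key)


class Actor:
    """Stand-in for ProbabilisticActor: forwards its kwargs unchanged."""

    def __init__(self, **kwargs):
        self.module = Safe(**kwargs)


def try_build(**kwargs):
    """Return the TypeError message raised by Actor(**kwargs), or None."""
    try:
        Actor(**kwargs)
    except TypeError as err:
        return str(err)
    return None


# The intermediate class rejects the keyword before it can reach Base.
error_message = try_build(log_prob_keys=["a_log_prob", "b_log_prob"])
print(error_message)
```

Passing only `log_prob_key` to the same stand-in succeeds, so in this model the error comes from the missing parameter in the intermediate constructor.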
To Reproduce
Disable logprob aggregation by using set_composite_lp_aggregate(...).set(), pass the return_log_prob argument to the ProbabilisticActor constructor, and provide a sequence of logprob keys via the log_prob_keys argument.
Expected behavior
The TypeError does not occur.
System info
torchrl + tensordict from the current main branches.
Reason and Possible fixes
To fix this, it should be enough to add the expected argument to the __init__ method of SafeProbabilisticModule and to pass it on to the constructor of its superclass, ProbabilisticTensorDictModule, which already defines the argument:
class SafeProbabilisticModule(ProbabilisticTensorDictModule):
    def __init__(
        self,
        in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
        out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
        spec: Optional[TensorSpec] = None,
        safe: bool = False,
        default_interaction_type: str = InteractionType.DETERMINISTIC,
        distribution_class: Type = Delta,
        distribution_kwargs: Optional[dict] = None,
        return_log_prob: bool = False,
        log_prob_key: NestedKey | None = None,
        log_prob_keys: List[NestedKey] | None = None,  # <----- here
        cache_dist: bool = False,
        n_empirical_estimate: int = 1000,
    ):
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            default_interaction_type=default_interaction_type,
            distribution_class=distribution_class,
            distribution_kwargs=distribution_kwargs,
            return_log_prob=return_log_prob,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # <----- here
            cache_dist=cache_dist,
            n_empirical_estimate=n_empirical_estimate,
        )
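The effect of the proposed change can be checked on a simplified model. This is a sketch under assumptions, not the torchrl implementation: `Base`, `FixedSafe`, and `Actor` are hypothetical stand-ins for ProbabilisticTensorDictModule, the patched SafeProbabilisticModule, and ProbabilisticActor, and the key names are invented.

```python
class Base:
    """Stand-in for ProbabilisticTensorDictModule: accepts both arguments."""

    def __init__(self, log_prob_key=None, log_prob_keys=None):
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class FixedSafe(Base):
    """Stand-in for the patched SafeProbabilisticModule."""

    def __init__(self, log_prob_key=None, log_prob_keys=None):
        # Declare log_prob_keys and forward it to the superclass.
        super().__init__(log_prob_key=log_prob_key, log_prob_keys=log_prob_keys)


class Actor:
    """Stand-in for ProbabilisticActor: forwards its kwargs unchanged."""

    def __init__(self, **kwargs):
        self.module = FixedSafe(**kwargs)


keys = ["a_log_prob", "b_log_prob"]
actor = Actor(log_prob_keys=keys)

# The keys now reach the base class instead of raising a TypeError.
print(actor.module.log_prob_keys)
```

A regression test for the real fix could follow the same shape: construct the actor with log_prob_keys and assert that the keys are stored on the module.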
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)