
[BUG] SafeProbabilisticModule constructor missing log_prob_keys argument #2731

Open
@rerz

Description


Describe the bug

When log-prob aggregation is disabled for a probabilistic actor, you are supposed to pass a sequence of keys through the log_prob_keys parameter instead of a single log_prob_key. However, this parameter is not passed up the class hierarchy in the constructor.
ProbabilisticActor forwards its keyword arguments to SafeProbabilisticModule correctly, but SafeProbabilisticModule.__init__ does not accept a log_prob_keys argument, so the call fails with a TypeError:

TypeError: SafeProbabilisticModule.__init__() got an unexpected keyword argument 'log_prob_keys'
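The failure mode can be reproduced without torchrl. The sketch below uses simplified, hypothetical stand-ins (ParentModule for ProbabilisticTensorDictModule, SafeChildModule for SafeProbabilisticModule, Actor for ProbabilisticActor) and assumes, as described above, that the actor forwards its keyword arguments unchanged to a subclass whose explicit signature omits log_prob_keys. It is an illustration of the pattern, not torchrl's code.

```python
from typing import List, Optional, Sequence


class ParentModule:
    """Hypothetical stand-in for ProbabilisticTensorDictModule: accepts log_prob_keys."""

    def __init__(
        self,
        return_log_prob: bool = False,
        log_prob_key: Optional[str] = None,
        log_prob_keys: Optional[List[str]] = None,
    ):
        self.return_log_prob = return_log_prob
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class SafeChildModule(ParentModule):
    """Hypothetical stand-in for SafeProbabilisticModule: its explicit signature
    omits log_prob_keys, so the argument can never reach the parent."""

    def __init__(
        self,
        return_log_prob: bool = False,
        log_prob_key: Optional[str] = None,
    ):
        super().__init__(return_log_prob=return_log_prob, log_prob_key=log_prob_key)


class Actor(SafeChildModule):
    """Hypothetical stand-in for ProbabilisticActor: forwards **kwargs unchanged."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


def try_build_actor(log_prob_keys: Sequence[str]) -> str:
    """Return the TypeError message raised while building the actor, or "" if none."""
    try:
        Actor(return_log_prob=True, log_prob_keys=list(log_prob_keys))
    except TypeError as err:
        return str(err)
    return ""


if __name__ == "__main__":
    # Prints the same kind of "unexpected keyword argument" error as above.
    print(try_build_actor(["action_log_prob"]))
```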

To Reproduce

Disable log-prob aggregation with set_composite_lp_aggregate(...).set(), then construct a ProbabilisticActor with return_log_prob=True and pass a sequence of log-prob keys through the log_prob_keys argument.

Expected behavior

ProbabilisticActor accepts log_prob_keys and forwards it to ProbabilisticTensorDictModule without raising a TypeError.

System info

torchrl and tensordict, both installed from their current main branches.

Reason and Possible fixes

A fix should only require adding the expected argument to SafeProbabilisticModule.__init__ and forwarding it to the constructor of its superclass, ProbabilisticTensorDictModule, which already defines log_prob_keys:

class SafeProbabilisticModule(ProbabilisticTensorDictModule):
    def __init__(
        self,
        in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
        out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
        spec: Optional[TensorSpec] = None,
        safe: bool = False,
        default_interaction_type: str = InteractionType.DETERMINISTIC,
        distribution_class: Type = Delta,
        distribution_kwargs: Optional[dict] = None,
        return_log_prob: bool = False,
        log_prob_key: NestedKey | None = None,
        log_prob_keys: List[NestedKey] | None = None,  # <----- here
        cache_dist: bool = False,
        n_empirical_estimate: int = 1000,
    ):
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            default_interaction_type=default_interaction_type,
            distribution_class=distribution_class,
            distribution_kwargs=distribution_kwargs,
            return_log_prob=return_log_prob,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # <----- here
            cache_dist=cache_dist,
            n_empirical_estimate=n_empirical_estimate,
        )
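The proposed fix can be checked with the same kind of simplified stand-ins. In this sketch (hypothetical classes again, not torchrl's code), the subclass declares log_prob_keys and passes it to super().__init__, so the keys forwarded by the actor reach the parent instead of raising.

```python
from typing import List, Optional


class ParentModule:
    """Hypothetical stand-in for ProbabilisticTensorDictModule."""

    def __init__(
        self,
        return_log_prob: bool = False,
        log_prob_key: Optional[str] = None,
        log_prob_keys: Optional[List[str]] = None,
    ):
        self.return_log_prob = return_log_prob
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class SafeChildModule(ParentModule):
    """Hypothetical stand-in for SafeProbabilisticModule with the fix applied."""

    def __init__(
        self,
        return_log_prob: bool = False,
        log_prob_key: Optional[str] = None,
        log_prob_keys: Optional[List[str]] = None,  # newly accepted argument
    ):
        super().__init__(
            return_log_prob=return_log_prob,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # forwarded to the parent
        )


class Actor(SafeChildModule):
    """Hypothetical stand-in for ProbabilisticActor: forwards **kwargs unchanged."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


if __name__ == "__main__":
    actor = Actor(return_log_prob=True, log_prob_keys=["action_log_prob"])
    print(actor.log_prob_keys)
```

Declaring the argument explicitly, as proposed, keeps the signature self-documenting. The trade-off is that every new argument added to the parent has to be mirrored in the subclass, which is how this bug arose in the first place.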

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels: bug
