[BUG] SafeProbabilisticModule constructor missing log_prob_keys argument #2731

Open
3 tasks done
rerz opened this issue Jan 29, 2025 · 0 comments
Labels
bug Something isn't working


rerz commented Jan 29, 2025

Describe the bug

When disabling log-prob aggregation for a probabilistic actor, you are supposed to pass a sequence of keys via log_prob_keys instead of a single log_prob_key. However, this argument is not passed up the class hierarchy in the constructor.
ProbabilisticActor forwards its kwargs to SafeProbabilisticModule correctly, but the SafeProbabilisticModule constructor does not accept a log_prob_keys argument, which results in a TypeError:

TypeError: SafeProbabilisticModule.__init__() got an unexpected keyword argument 'log_prob_keys'
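The failure mode can be illustrated without torchrl. The sketch below uses hypothetical stand-in classes (BaseModule, SafeModule and Actor are illustrative names, not library classes) that mirror the hierarchy described in this report: the bottom class accepts log_prob_keys, the middle class has an explicit signature that omits it, and the top class forwards **kwargs unchanged.

```python
# Hypothetical stand-ins that mirror the class hierarchy in this report.
# They are NOT the real tensordict/torchrl classes.


class BaseModule:
    """Role of ProbabilisticTensorDictModule: accepts log_prob_keys."""

    def __init__(self, in_keys, log_prob_key=None, log_prob_keys=None):
        self.in_keys = in_keys
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class SafeModule(BaseModule):
    """Role of SafeProbabilisticModule before the fix: its explicit
    signature omits log_prob_keys, so it can neither receive nor forward it."""

    def __init__(self, in_keys, log_prob_key=None):
        super().__init__(in_keys=in_keys, log_prob_key=log_prob_key)


class Actor(SafeModule):
    """Role of ProbabilisticActor: forwards any extra kwargs unchanged."""

    def __init__(self, in_keys, **kwargs):
        super().__init__(in_keys=in_keys, **kwargs)


def build_error():
    """Return the TypeError message raised while building the actor, or None."""
    try:
        Actor(in_keys=["loc", "scale"], log_prob_keys=["a_log_prob", "b_log_prob"])
    except TypeError as err:
        return str(err)
    return None


print(build_error())
```

The TypeError is raised by SafeModule.__init__, not by Actor: Actor accepts the keyword through **kwargs, and the call only fails one level down, where the explicit signature has no slot for it. On Python 3.10 and later the message should name SafeModule.__init__, matching the shape of the error above.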

To Reproduce

Disable log-prob aggregation with set_composite_lp_aggregate(...).set(), pass return_log_prob=True to the ProbabilisticActor constructor, and provide a sequence of log-prob keys via the log_prob_keys argument.

Expected behavior

The actor is constructed without a TypeError when log_prob_keys is passed.

System info

torchrl + tensordict from the current main branches.

Reason and Possible fixes

To fix this, it should be enough to add the expected argument to SafeProbabilisticModule.__init__ and forward it to the superclass constructor, ProbabilisticTensorDictModule.__init__, which already defines the argument correctly:

from __future__ import annotations

from typing import Dict, List, Optional, Type, Union

from tensordict.nn import InteractionType, ProbabilisticTensorDictModule
from tensordict.utils import NestedKey
from torchrl.data.tensor_specs import TensorSpec
from torchrl.modules.distributions import Delta


class SafeProbabilisticModule(ProbabilisticTensorDictModule):
    def __init__(
        self,
        in_keys: Union[NestedKey, List[NestedKey], Dict[str, NestedKey]],
        out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None,
        spec: Optional[TensorSpec] = None,
        safe: bool = False,
        default_interaction_type: str = InteractionType.DETERMINISTIC,
        distribution_class: Type = Delta,
        distribution_kwargs: Optional[dict] = None,
        return_log_prob: bool = False,
        log_prob_key: NestedKey | None = None,
        log_prob_keys: List[NestedKey] | None = None,  # <-- new parameter
        cache_dist: bool = False,
        n_empirical_estimate: int = 1000,
    ):
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            default_interaction_type=default_interaction_type,
            distribution_class=distribution_class,
            distribution_kwargs=distribution_kwargs,
            return_log_prob=return_log_prob,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # <-- forward it to ProbabilisticTensorDictModule
            cache_dist=cache_dist,
            n_empirical_estimate=n_empirical_estimate,
        )
        # ... rest of the existing __init__ body (spec/safe handling) unchanged
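A torchrl-free sketch of the proposed change, using hypothetical stand-in classes (BaseModule, SafeModule and Actor are illustrative names, not library classes) in the roles of ProbabilisticTensorDictModule, SafeProbabilisticModule and ProbabilisticActor. Once the middle class declares log_prob_keys and forwards it, the keys reach the base class and no TypeError is raised.

```python
# Hypothetical stand-ins (not the real tensordict/torchrl classes) showing
# the proposed fix: the middle class declares log_prob_keys and forwards it.
from typing import List, Optional


class BaseModule:
    """Role of ProbabilisticTensorDictModule: already accepts log_prob_keys."""

    def __init__(self, in_keys, log_prob_key=None, log_prob_keys=None):
        self.in_keys = in_keys
        self.log_prob_key = log_prob_key
        self.log_prob_keys = log_prob_keys


class SafeModule(BaseModule):
    """Role of SafeProbabilisticModule after the fix."""

    def __init__(
        self,
        in_keys,
        log_prob_key: Optional[str] = None,
        log_prob_keys: Optional[List[str]] = None,  # newly declared parameter
    ):
        super().__init__(
            in_keys=in_keys,
            log_prob_key=log_prob_key,
            log_prob_keys=log_prob_keys,  # forwarded to the base class
        )


class Actor(SafeModule):
    """Role of ProbabilisticActor: forwards any extra kwargs unchanged."""

    def __init__(self, in_keys, **kwargs):
        super().__init__(in_keys=in_keys, **kwargs)


actor = Actor(in_keys=["loc", "scale"], log_prob_keys=["a_log_prob", "b_log_prob"])
print(actor.log_prob_keys)  # ['a_log_prob', 'b_log_prob']
```

Declaring the parameter explicitly, as proposed, keeps the signature self-documenting. An alternative some codebases use is to accept and forward **kwargs in the middle class, which would avoid this kind of omission for future base-class arguments at the cost of a less explicit signature.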

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
rerz added the bug (Something isn't working) label Jan 29, 2025
rerz changed the title [BUG] SafeProbabilisticModule constructor missing log_prob_keys argument [BUG] SafeProbabilisticModule constructor missing log_prob_keys argument Jan 29, 2025