Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing of poutine.trace(model).get_trace() #3334

Merged
merged 2 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@
from pyro.poutine.lift_messenger import LiftMessenger
from pyro.poutine.markov_messenger import MarkovMessenger
from pyro.poutine.mask_messenger import MaskMessenger
from pyro.poutine.reparam_messenger import ReparamMessenger
from pyro.poutine.reparam_messenger import ReparamHandler, ReparamMessenger
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.runtime import NonlocalExit
from pyro.poutine.scale_messenger import ScaleMessenger
from pyro.poutine.seed_messenger import SeedMessenger
from pyro.poutine.substitute_messenger import SubstituteMessenger
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_messenger import TraceHandler, TraceMessenger
from pyro.poutine.uncondition_messenger import UnconditionMessenger

if TYPE_CHECKING:
Expand Down Expand Up @@ -152,7 +152,7 @@ def block(

@overload
def block(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
hide_all: bool = True,
Expand Down Expand Up @@ -186,7 +186,7 @@ def broadcast(

@overload
def broadcast(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
) -> Callable[_P, _T]: ...


Expand All @@ -206,7 +206,7 @@ def collapse(

@overload
def collapse(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
*args: Any,
**kwargs: Any,
) -> Callable[_P, _T]: ...
Expand Down Expand Up @@ -269,7 +269,7 @@ def enum(

@overload
def enum(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
first_available_dim: Optional[int] = None,
) -> Callable[_P, _T]: ...

Expand Down Expand Up @@ -371,14 +371,14 @@ def reparam(
def reparam(
fn: Callable[_P, _T],
config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
) -> Callable[_P, _T]: ...
) -> ReparamHandler[_P, _T]: ...


@_make_handler(ReparamMessenger)
def reparam( # type: ignore[empty-body]
fn: Callable[_P, _T],
config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
) -> Union[ReparamMessenger, Callable[_P, _T]]: ...
) -> Union[ReparamMessenger, ReparamHandler[_P, _T]]: ...


@overload
Expand All @@ -391,7 +391,7 @@ def replay(

@overload
def replay(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
trace: Optional["Trace"] = None,
params: Optional[Dict[str, "torch.Tensor"]] = None,
) -> Callable[_P, _T]: ...
Expand Down Expand Up @@ -475,18 +475,18 @@ def trace(

@overload
def trace(
fn: Callable[_P, _T] = ...,
fn: Callable[_P, _T],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to remove the default to avoid the mypy error:

pyro/poutine/handlers.py:469: error: Overloaded function signatures 1 and 2 overlap with incompatible return types  [overload-overlap]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, seems like some weird behavior from mypy

graph_type: Optional[Literal["flat", "dense"]] = None,
param_only: Optional[bool] = None,
) -> Callable[_P, _T]: ...
) -> TraceHandler[_P, _T]: ...


@_make_handler(TraceMessenger)
def trace( # type: ignore[empty-body]
fn: Optional[Callable[_P, _T]] = None,
graph_type: Optional[Literal["flat", "dense"]] = None,
param_only: Optional[bool] = None,
) -> Union[TraceMessenger, Callable[_P, _T]]: ...
) -> Union[TraceMessenger, TraceHandler[_P, _T]]: ...


@overload
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/reparam_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
self.config = config
self._args_kwargs = None

def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
def __call__(self, fn: Callable[_P, _T]) -> "ReparamHandler[_P, _T]":
return ReparamHandler(self, fn)

def _pyro_sample(self, msg: "Message") -> None:
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/trace_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __exit__(self, *args, **kwargs) -> None:
identify_dense_edges(self.trace)
return super().__exit__(*args, **kwargs)

def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
def __call__(self, fn: Callable[_P, _T]) -> "TraceHandler[_P, _T]":
"""
TODO docs
"""
Expand Down
Loading