Skip to content

Commit

Permalink
Add handlers explicitly (#3285)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Oct 24, 2023
1 parent ffc5e24 commit c605f41
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 73 deletions.
6 changes: 3 additions & 3 deletions pyro/contrib/autoname/autoname.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def _pyro_genname(msg):
msg["stop"] = True


_handler_name, _handler = _make_handler(AutonameMessenger)
_handler.__module__ = __name__
locals()[_handler_name] = _handler
@_make_handler(AutonameMessenger, __name__)
def autoname(fn=None, name=None):
...


@singledispatch
Expand Down
57 changes: 43 additions & 14 deletions pyro/contrib/funsor/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,46 @@
from .replay_messenger import ReplayMessenger
from .trace_messenger import TraceMessenger

_msngrs = [
EnumMessenger,
MarkovMessenger,
NamedMessenger,
PlateMessenger,
ReplayMessenger,
TraceMessenger,
VectorizedMarkovMessenger,
]

for _msngr_cls in _msngrs:
_handler_name, _handler = _make_handler(_msngr_cls)
_handler.__module__ = __name__
locals()[_handler_name] = _handler

@_make_handler(EnumMessenger, __name__)
def enum(fn=None, first_available_dim=None):
...


@_make_handler(MarkovMessenger, __name__)
def markov(fn=None, history=1, keep=False):
...


@_make_handler(NamedMessenger, __name__)
def named(fn=None, first_available_dim=None):
...


@_make_handler(PlateMessenger, __name__)
def plate(
fn=None,
name=None,
size=None,
subsample_size=None,
subsample=None,
dim=None,
use_cuda=None,
device=None,
):
...


@_make_handler(ReplayMessenger, __name__)
def replay(fn=None, trace=None, params=None):
...


@_make_handler(TraceMessenger, __name__)
def trace(fn=None, graph_type=None, param_only=None, pack_online=True):
...


@_make_handler(VectorizedMarkovMessenger, __name__)
def vectorized_markov(fn=None, name=None, size=None, dim=None, history=1):
...
173 changes: 117 additions & 56 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@

import collections
import functools
import re

from pyro.poutine import util

Expand Down Expand Up @@ -79,68 +78,130 @@
# Begin primitive operations
############################################

_msngrs = [
BlockMessenger,
BroadcastMessenger,
CollapseMessenger,
ConditionMessenger,
DoMessenger,
EnumMessenger,
EscapeMessenger,
InferConfigMessenger,
LiftMessenger,
MaskMessenger,
ReparamMessenger,
ReplayMessenger,
ScaleMessenger,
SeedMessenger,
TraceMessenger,
UnconditionMessenger,
SubstituteMessenger,
]


def _make_handler(msngr_cls):
_re1 = re.compile("(.)([A-Z][a-z]+)")
_re2 = re.compile("([a-z0-9])([A-Z])")

def handler(fn=None, *args, **kwargs):
if fn is not None and not (
callable(fn) or isinstance(fn, collections.abc.Iterable)
):
raise ValueError(
"{} is not callable, did you mean to pass it as a keyword arg?".format(
fn

def _make_handler(msngr_cls, module=None):
def handler_decorator(func):
def handler(fn=None, *args, **kwargs):
if fn is not None and not (
callable(fn) or isinstance(fn, collections.abc.Iterable)
):
raise ValueError(
f"{fn} is not callable, did you mean to pass it as a keyword arg?"
)
msngr = msngr_cls(*args, **kwargs)
return (
functools.update_wrapper(msngr(fn), fn, updated=())
if fn is not None
else msngr
)
msngr = msngr_cls(*args, **kwargs)
return (
functools.update_wrapper(msngr(fn), fn, updated=())
if fn is not None
else msngr
)

# handler names from messenger names: strip Messenger suffix, convert CamelCase to snake_case
handler_name = _re2.sub(
r"\1_\2", _re1.sub(r"\1_\2", msngr_cls.__name__.split("Messenger")[0])
).lower()
handler.__doc__ = (
"""Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format(
handler_name + "_messenger", msngr_cls.__name__
handler.__doc__ = (
"""Convenient wrapper of :class:`~pyro.poutine.{}.{}` \n\n""".format(
func.__name__ + "_messenger", msngr_cls.__name__
)
+ (msngr_cls.__doc__ if msngr_cls.__doc__ else "")
)
+ (msngr_cls.__doc__ if msngr_cls.__doc__ else "")
)
handler.__name__ = handler_name
return handler_name, handler
handler.__name__ = func.__name__
if module is not None:
handler.__module__ = module
return handler

return handler_decorator


@_make_handler(BlockMessenger)
def block(
fn=None,
hide_fn=None,
expose_fn=None,
hide_all=True,
expose_all=False,
hide=None,
expose=None,
hide_types=None,
expose_types=None,
):
...


@_make_handler(BroadcastMessenger)
def broadcast(fn=None):
...


@_make_handler(CollapseMessenger)
def collapse(fn=None, *args, **kwargs):
...


@_make_handler(ConditionMessenger)
def condition(fn, data):
...


@_make_handler(DoMessenger)
def do(fn, data):
...


@_make_handler(EnumMessenger)
def enum(fn=None, first_available_dim=None):
...


@_make_handler(EscapeMessenger)
def escape(fn, escape_fn):
...


@_make_handler(InferConfigMessenger)
def infer_config(fn, config_fn):
...


@_make_handler(LiftMessenger)
def lift(fn, prior):
...


@_make_handler(MaskMessenger)
def mask(fn, mask):
...


@_make_handler(ReparamMessenger)
def reparam(fn, config):
...


@_make_handler(ReplayMessenger)
def replay(fn=None, trace=None, params=None):
...


@_make_handler(ScaleMessenger)
def scale(fn, scale):
...


@_make_handler(SeedMessenger)
def seed(fn, rng_seed):
...


@_make_handler(TraceMessenger)
def trace(fn=None, graph_type=None, param_only=None):
...


@_make_handler(UnconditionMessenger)
def uncondition(fn=None):
...

trace = None # flake8
escape = None # flake8

for _msngr_cls in _msngrs:
_handler_name, _handler = _make_handler(_msngr_cls)
_handler.__module__ = __name__
locals()[_handler_name] = _handler
@_make_handler(SubstituteMessenger)
def substitute(fn, data):
...


#########################################
Expand Down

0 comments on commit c605f41

Please sign in to comment.