From 795b6047ae14bb4d1e199fc2d3f36fc6e98b7592 Mon Sep 17 00:00:00 2001 From: Jake Simones Date: Fri, 13 Sep 2019 16:20:54 -0500 Subject: [PATCH] Add unfold --- doc/source/api.rst | 2 ++ toolz/curried/__init__.py | 2 ++ toolz/functoolz.py | 54 ++++++++++++++++++++++++++++++++++- toolz/tests/test_functoolz.py | 17 ++++++++++- 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/doc/source/api.rst b/doc/source/api.rst index 86ac4c0b..7abdb468 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -71,6 +71,8 @@ Functoolz pipe thread_first thread_last + unfold + unfold_ Dicttoolz --------- diff --git a/toolz/curried/__init__.py b/toolz/curried/__init__.py index 356eddbd..d04bc2eb 100644 --- a/toolz/curried/__init__.py +++ b/toolz/curried/__init__.py @@ -98,6 +98,8 @@ update_in = toolz.curry(toolz.update_in) valfilter = toolz.curry(toolz.valfilter) valmap = toolz.curry(toolz.valmap) +unfold = toolz.curry(toolz.unfold) +unfold_ = toolz.curry(toolz.unfold_) del exceptions del toolz diff --git a/toolz/functoolz.py b/toolz/functoolz.py index 01d3857a..3ab614bb 100644 --- a/toolz/functoolz.py +++ b/toolz/functoolz.py @@ -12,7 +12,7 @@ __all__ = ('identity', 'apply', 'thread_first', 'thread_last', 'memoize', 'compose', 'compose_left', 'pipe', 'complement', 'juxt', 'do', - 'curry', 'flip', 'excepts') + 'curry', 'flip', 'excepts', 'unfold', 'unfold_') def identity(x): @@ -825,6 +825,58 @@ def __name__(self): return 'excepting' +def unfold(func, x): + """ Generate values from a seed value + + Each iteration, the generator yields ``func(x)[0]`` and evaluates + ``func(x)[1]`` to determine the next ``x`` value. Iteration proceeds as + long as ``func(x)`` is not None. + + >>> def doubles(x): + ... if x > 10: + ... return None + ... else: + ... return (x * 2, x + 1) + ... + >>> list(unfold(doubles, 1)) + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + + If ``x`` has type ``A`` and the generator yields values of type ``B``, + then ``func`` has type ``Callable[[A], Optional[Tuple[B, A]]]``. + + """ + while True: + t = func(x) + if t is None: + break + else: + yield t[0] + x = t[1] + + +def unfold_(predicate, func, succ, x): + """ Alternative formulation of unfold + + Each iteration, the generator yields ``func(x)`` and evaluates + ``succ(x)`` to determine the next ``x`` value. Iteration proceeds as long + as ``predicate(x)`` is True. + + >>> lte10 = lambda x: x <= 10 + >>> double = lambda x: x * 2 + >>> inc = lambda x: x + 1 + >>> list(unfold_(lte10, double, inc, 1)) + [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + + If ``x`` has type ``A`` and the generator yields values of type ``B``, + then ``predicate`` has type ``Callable[[A], bool]``, ``func`` has type + ``Callable[[A], B]``, and ``succ`` has type ``Callable[[A], A]``. + + """ + while predicate(x): + yield func(x) + x = succ(x) + + if PY3: # pragma: py2 no cover def _check_sigspec(sigspec, func, builtin_func, *builtin_args): if sigspec is None: diff --git a/toolz/tests/test_functoolz.py b/toolz/tests/test_functoolz.py index 4938df8a..154771ff 100644 --- a/toolz/tests/test_functoolz.py +++ b/toolz/tests/test_functoolz.py @@ -2,7 +2,7 @@ import toolz from toolz.functoolz import (thread_first, thread_last, memoize, curry, compose, compose_left, pipe, complement, do, juxt, - flip, excepts, apply) + flip, excepts, apply, unfold, unfold_) from toolz.compatibility import PY3 from operator import add, mul, itemgetter from toolz.utils import raises @@ -796,3 +796,18 @@ def raise_(a): excepting = excepts(object(), object(), object()) assert excepting.__name__ == 'excepting' assert excepting.__doc__ == excepts.__doc__ + + +def test_unfold(): + expected = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20] + + def doubles(x): + if x > 10: + return None + else: + return (x * 2, x + 1) + assert list(unfold(doubles, 1)) == expected + + def lte10(x): + return x <= 10 + assert list(unfold_(lte10, double, inc, 1)) == expected