🐛 Fix sequential stream consuming source eager
garlontas committed Nov 25, 2024
1 parent fd70cad commit 6f4fdb1
Showing 9 changed files with 193 additions and 52 deletions.
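
The fix targets intermediate operations that materialized the stream source (for example via list()) before any terminal operation ran, which made large or generator-backed sources unusable. A minimal sketch, not part of the commit, of the eager-versus-lazy difference the change addresses; the numbers() generator and variable names are hypothetical:

    import itertools

    def numbers():
        # Hypothetical infinite source standing in for a generator-backed stream.
        n = 0
        while True:
            yield n
            n += 1

    # Eager: materializing the source first never returns on an infinite generator.
    # first_three = list(numbers())[:3]

    # Lazy: a generator pipeline pulls only the elements it actually needs.
    first_three = list(itertools.islice(numbers(), 3))   # [0, 1, 2]
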
2 changes: 1 addition & 1 deletion pystreamapi/__stream.py
@@ -87,7 +87,7 @@ def concat(*streams: "BaseStream[_K]"):
:param streams: The streams to concatenate
:return: The concatenated stream
"""
return streams[0].__class__(itertools.chain(*list(streams)))
return streams[0].__class__(itertools.chain(*iter(streams)))

@staticmethod
def iterate(seed: _K, func: Callable[[_K], _K]) -> BaseStream[_K]:
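
The static concat above now unpacks an iterator over the streams rather than an intermediate list; in both cases itertools.chain consumes the chained streams lazily, element by element. A tiny sketch of that chaining behaviour, with made-up generator sources:

    import itertools

    a = (x for x in [1, 2])
    b = (x for x in [3, 4])
    chained = itertools.chain(a, b)   # nothing consumed yet
    print(next(chained))              # 1
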
37 changes: 35 additions & 2 deletions pystreamapi/_itertools/tools.py
@@ -1,8 +1,10 @@
# pylint: disable=protected-access
from typing import Iterable

from pystreamapi._streams.error.__error import ErrorHandler, _sentinel


def dropwhile(predicate, iterable, handler: ErrorHandler=None):
def dropwhile(predicate, iterable, handler: ErrorHandler = None):
"""
Drop items from the iterable while predicate(item) is true.
Afterward, return every element until the iterable is exhausted.
@@ -22,7 +24,7 @@ def dropwhile(predicate, iterable, handler: ErrorHandler=None):
_initial_missing = object()


def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler=None):
def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler = None):
"""
Apply a function of two arguments cumulatively to the items of a sequence
or iterable, from left to right, to reduce the iterable to a single
@@ -51,3 +53,34 @@ def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler=None):
value = function(value, element)

return value


def peek(iterable: Iterable, mapper):
for item in iterable:
mapper(item)
yield item

def distinct(iterable):
seen = set()
for item in iterable:
if item not in seen:
seen.add(item)
yield item

def limit(source: Iterable, max_nr: int):
iterator = iter(source)
for _ in range(max_nr):
try:
yield next(iterator)
except StopIteration:
break

def any_match(iterable):
for item in iterable:
if item:
return True
return False

def flat_map(iterable):
for stream in iterable:
yield from stream.to_list()
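
Most of the helpers added above (peek, distinct, limit, flat_map) are generators, while any_match short-circuits on the first truthy element, so they compose without consuming the source up front. A short usage sketch under that assumption, using the module path shown in this diff; the sample data is made up:

    from pystreamapi._itertools.tools import distinct, limit, peek

    seen = []
    source = iter([3, 1, 3, 2, 1, 5, 4])
    # Nothing is consumed until the pipeline is iterated.
    pipeline = limit(peek(distinct(source), seen.append), 3)
    print(list(pipeline))   # [3, 1, 2]
    print(seen)             # [3, 1, 2] -- peek only saw the three consumed items
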
34 changes: 23 additions & 11 deletions pystreamapi/_streams/__base_stream.py
@@ -1,5 +1,6 @@
# pylint: disable=protected-access
from __future__ import annotations

import functools
import itertools
from abc import abstractmethod
@@ -8,7 +9,7 @@
from typing import Iterable, Callable, Any, TypeVar, Iterator, TYPE_CHECKING, Union

from pystreamapi.__optional import Optional
from pystreamapi._itertools.tools import dropwhile
from pystreamapi._itertools.tools import dropwhile, distinct, limit, any_match
from pystreamapi._lazy.process import Process
from pystreamapi._lazy.queue import ProcessQueue
from pystreamapi._streams.error.__error import ErrorHandler
@@ -85,16 +86,19 @@ def _verify_open(self):
def __iter__(self) -> Iterator[K]:
return iter(self._source)

@classmethod
def concat(cls, *streams: "BaseStream[K]"):
def concat(self, *streams: "BaseStream[K]") -> BaseStream[K]:
"""
Creates a lazily concatenated stream whose elements are all the elements of the first stream
followed by all the elements of the other streams.
:param streams: The streams to concatenate
:return: The concatenated stream
"""
return cls(itertools.chain(*list(streams)))
self._queue.execute_all()
for stream in streams:
stream._queue.execute_all()
self._source = itertools.chain(self._source, *[stream._source for stream in streams])
return self

@_operation
def distinct(self) -> 'BaseStream[K]':
@@ -104,7 +108,7 @@ def distinct(self) -> 'BaseStream[K]':

def __distinct(self):
"""Removes duplicate elements from the stream."""
self._source = list(set(self._source))
self._source = distinct(self._source)

@_operation
def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':
@@ -119,7 +123,7 @@ def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':

def __drop_while(self, predicate: Callable[[Any], bool]):
"""Drops elements from the stream while the predicate is true."""
self._source = list(dropwhile(predicate, self._source, self))
self._source = dropwhile(predicate, self._source, self)

def error_level(self, level: ErrorLevel, *exceptions)\
-> Union["BaseStream[K]", NumericBaseStream]:
@@ -160,7 +164,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]':
return self

@abstractmethod
def _flat_map(self, predicate: Callable[[K], Iterable[_V]]):
def _flat_map(self, mapper: Callable[[K], Iterable[_V]]):
"""Implementation of flat_map. Should be implemented by subclasses."""

@_operation
@@ -196,7 +200,7 @@ def limit(self, max_size: int) -> 'BaseStream[K]':

def __limit(self, max_size: int):
"""Limits the stream to the first n elements."""
self._source = itertools.islice(self._source, max_size)
self._source = limit(self._source, max_size)

@_operation
def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]':
@@ -283,6 +287,7 @@ def reversed(self) -> 'BaseStream[K]':
"""
Returns a stream consisting of the elements of this stream, with their order being
reversed.
This does not work on infinite generators.
"""
self._queue.append(Process(self.__reversed))
return self
@@ -314,7 +319,7 @@ def skip(self, n: int) -> 'BaseStream[K]':

def __skip(self, n: int):
"""Skips the first n elements of the stream."""
self._source = self._source[n:]
self._source = itertools.islice(self._source, n, None)

@_operation
def sorted(self, comparator: Callable[[K], int] = None) -> 'BaseStream[K]':
@@ -345,7 +350,7 @@ def take_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]':

def __take_while(self, predicate: Callable[[Any], bool]):
"""Takes elements from the stream while the predicate is true."""
self._source = list(itertools.takewhile(predicate, self._source))
self._source = itertools.takewhile(predicate, self._source)

@abstractmethod
@terminal
@@ -363,7 +368,12 @@ def any_match(self, predicate: Callable[[K], bool]):
:param predicate: The callable predicate
"""
return any(self._itr(self._source, predicate))
def _one_wrapper(iterable, mapper):
for i in iterable:
yield self._one(mapper, item=i)

self._source = _one_wrapper(self._source, predicate)
return any_match(self._source)

@terminal
def count(self):
@@ -413,13 +423,15 @@ def none_match(self, predicate: Callable[[K], bool]):
@terminal
def min(self):
"""Returns the minimum element of this stream."""
self._source = list(self._source)
if len(self._source) > 0:
return Optional.of(min(self._source))
return Optional.empty()

@terminal
def max(self):
"""Returns the maximum element of this stream."""
self._source = list(self._source)
if len(self._source) > 0:
return Optional.of(max(self._source))
return Optional.empty()
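
In this file the intermediate operations (distinct, drop_while, limit, skip, take_while) now rewrap self._source in lazy iterators instead of lists, concat becomes an instance method that chains the other streams' sources onto the current one, and terminal operations such as min and max materialize the source explicitly before checking emptiness. A sketch, not taken from the library, of the islice/takewhile pattern the diff switches to; squares() is a hypothetical infinite source:

    import itertools

    def squares():
        n = 0
        while True:
            yield n * n
            n += 1

    src = squares()
    src = itertools.islice(src, 2, None)                 # skip(2), stays lazy
    src = itertools.takewhile(lambda x: x < 100, src)    # take_while, stays lazy
    print(list(src))   # [4, 9, 16, 25, 36, 49, 64, 81]
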
4 changes: 2 additions & 2 deletions pystreamapi/_streams/__parallel_stream.py
@@ -38,10 +38,10 @@ def find_any(self):
return Optional.of(self._source[0])
return Optional.empty()

def _flat_map(self, predicate: Callable[[Any], stream.BaseStream]):
def _flat_map(self, mapper: Callable[[Any], stream.BaseStream]):
new_src = []
for element in Parallel(n_jobs=-1, prefer="threads", handler=self)(
delayed(self.__mapper(predicate))(element) for element in self._source):
delayed(self.__mapper(mapper))(element) for element in self._source):
new_src.extend(element.to_list())
self._source = new_src

21 changes: 10 additions & 11 deletions pystreamapi/_streams/__sequential_stream.py
@@ -3,8 +3,8 @@

import pystreamapi._streams.__base_stream as stream
from pystreamapi.__optional import Optional
from pystreamapi._itertools.tools import reduce, flat_map, peek
from pystreamapi._streams.error.__error import _sentinel
from pystreamapi._itertools.tools import reduce

_identity_missing = object()

@@ -21,15 +21,13 @@ def _filter(self, predicate: Callable[[Any], bool]):

@stream.terminal
def find_any(self):
if len(self._source) > 0:
return Optional.of(self._source[0])
return Optional.empty()
try:
return Optional.of(next(iter(self._source)))
except StopIteration:
return Optional.empty()

def _flat_map(self, predicate: Callable[[Any], stream.BaseStream]):
new_src = []
for element in self._itr(self._source, mapper=predicate):
new_src.extend(element.to_list())
self._source = new_src
def _flat_map(self, mapper: Callable[[Any], stream.BaseStream]):
self._source = flat_map(self._itr(self._source, mapper=mapper))

def _group_to_dict(self, key_mapper: Callable[[Any], Any]):
groups = defaultdict(list)
@@ -43,13 +41,14 @@ def _group_to_dict(self, key_mapper: Callable[[Any], Any]):

@stream.terminal
def for_each(self, action: Callable):
self._peek(action)
for item in self._source:
self._one(mapper=action, item=item)

def _map(self, mapper: Callable[[Any], Any]):
self._source = self._itr(self._source, mapper=mapper)

def _peek(self, action: Callable):
self._itr(self._source, mapper=action)
self._source = peek(self._source, lambda x: self._one(mapper=action, item=x))

@stream.terminal
def reduce(self, predicate: Callable, identity=_identity_missing, depends_on_state=False):
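
find_any now pulls a single element from a possibly lazy source instead of indexing into a list, and flat_map and peek route through the new generator helpers. A minimal sketch of the next(iter(...)) pattern used above; the plain try/except mirrors the diff, while the library's Optional wrapper is left out here:

    src = (x for x in [10, 20, 30])
    try:
        first = next(iter(src))   # consumes exactly one element
    except StopIteration:
        first = None              # the diff returns Optional.empty() instead
    print(first)                  # 10
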
7 changes: 3 additions & 4 deletions pystreamapi/_streams/error/__error.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import Iterable

from pystreamapi._streams.error.__levels import ErrorLevel
from pystreamapi._streams.error.__sentinel import Sentinel
@@ -37,20 +38,18 @@ def _get_error_level(self):
"""Get the error level"""
return self.__error_level

def _itr(self, src, mapper=nothing, condition=true_condition) -> list:
def _itr(self, src, mapper=nothing, condition=true_condition) -> Iterable:
"""Iterate over the source and apply the mapper and condition"""
new_src = []
for i in src:
try:
if condition(i):
new_src.append(mapper(i))
yield mapper(i)
except self.__exceptions_to_ignore as e:
if self.__error_level == ErrorLevel.RAISE:
raise e
if self.__error_level == ErrorLevel.IGNORE:
continue
self.__log(e)
return new_src

def _one(self, mapper=nothing, condition=true_condition, item=None):
"""
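
Because _itr is now a generator, mapper errors surface only when its result is iterated, which is why the tests below wrap the calls in list(). A standalone sketch of that behaviour, with a simplified itr() standing in for the handler method:

    def itr(src, mapper):
        # Simplified stand-in for ErrorHandler._itr: yields mapped items lazily.
        for i in src:
            yield mapper(i)

    gen = itr([1, "a"], int)   # no error yet -- nothing has been consumed
    try:
        list(gen)              # ValueError is raised here, during iteration
    except ValueError:
        print("raised while iterating")
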
26 changes: 13 additions & 13 deletions tests/_streams/error/test_error_handler.py
@@ -16,42 +16,42 @@ def setUp(self) -> None:

def test_iterate_raise(self):
self.handler._error_level(ErrorLevel.RAISE)
self.assertRaises(ValueError, lambda: self.handler._itr([1, 2, 3, 4, 5, "a"], int))
self.assertRaises(ValueError, lambda: list(self.handler._itr([1, 2, 3, 4, 5, "a"], int)))

def test_iterate_raise_with_condition(self):
self.handler._error_level(ErrorLevel.RAISE)
self.assertRaises(ValueError, lambda: self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != ""))
self.assertRaises(ValueError, lambda: list(self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != "")))

def test_iterate_ignore(self):
self.handler._error_level(ErrorLevel.IGNORE)
self.assertEqual(self.handler._itr([1, 2, 3, 4, 5, "a"], int), [1, 2, 3, 4, 5])
self.assertEqual(list(self.handler._itr([1, 2, 3, 4, 5, "a"], int)), [1, 2, 3, 4, 5])

def test_iterate_ignore_with_condition(self):
self.handler._error_level(ErrorLevel.IGNORE)
self.assertEqual(self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != ""), [1, 2, 3, 4, 5])
self.assertEqual(list(self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != "")), [1, 2, 3, 4, 5])


def test_iterate_ignore_specific_exceptions(self):
self.handler._error_level(ErrorLevel.IGNORE, ValueError, AttributeError)
self.assertEqual(self.handler._itr(
["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split()), [["b"], ["a"]])
self.assertEqual(list(self.handler._itr(
["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split())), [["b"], ["a"]])


def test_iterate_ignore_specific_exception_raise_another(self):
self.handler._error_level(ErrorLevel.IGNORE, ValueError)
self.assertRaises(AttributeError, lambda: self.handler._itr(
["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split()))
self.assertRaises(AttributeError, lambda: list(self.handler._itr(
["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split())))

def test_iterate_warn(self):
self.handler._error_level(ErrorLevel.WARN)
self.assertEqual(self.handler._itr([1, 2, 3, 4, 5, "a"], int), [1, 2, 3, 4, 5])
self.assertEqual(list(self.handler._itr([1, 2, 3, 4, 5, "a"], int)), [1, 2, 3, 4, 5])

def test_iterate_warn_with_condition(self):
self.handler._error_level(ErrorLevel.WARN)
self.assertEqual(self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != ""), [1, 2, 3, 4, 5])
self.assertEqual(list(self.handler._itr(
[1, 2, 3, 4, 5, "a"], int, lambda x: x != "")), [1, 2, 3, 4, 5])

def test_one_raise(self):
self.handler._error_level(ErrorLevel.RAISE)

(Diffs for the remaining 2 changed files were not loaded.)
