Skip to content

Commit bb44552

Browse files
committed
Checking consistency of arguments in @Batch calls
1 parent 0471a9c commit bb44552

File tree

3 files changed

+95
-37
lines changed

3 files changed

+95
-37
lines changed

dike/__init__.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,35 @@ async def limited_call(*args, **kwargs):
106106

107107
# Deactivate mccabe's complexity warnings, since mccabe doesn't cope well with closures
108108
# flake8: noqa: C901
109-
def batch(*, target_batch_size: int, max_waiting_time: float, max_processing_time: float = 10.0):
109+
def batch(
110+
*,
111+
target_batch_size: int,
112+
max_waiting_time: float,
113+
max_processing_time: float = 10.0,
114+
argument_type: str = "list",
115+
):
110116
"""@batch is a decorator to cumulate function calls and process them in batches.
111117
Not thread-safe.
112118
113119
Args:
114120
target_batch_size: As soon as the collected function arguments reach target_batch_size,
115121
the wrapped function is called and the results are returned. Note that the function
116122
may also be called with longer arguments than target_batch_size.
117-
max_waiting_time: Maximum waiting time before calling the underlying function although
118-
the target_batch_size hasn't been reached.
119-
max_processing_time: Maximum time for the processing itself (without waiting) before an
120-
asyncio.TimeoutError is raised. Note: It is strongly advised to set a reasonably
121-
strict timeout here in order not to create starving tasks which never finish in case
122-
something is wrong with the backend call.
123+
max_waiting_time: Maximum waiting time in seconds before calling the underlying function
124+
although the target_batch_size hasn't been reached.
125+
max_processing_time: Maximum time in seconds for the processing itself (without waiting)
126+
before an asyncio.TimeoutError is raised. Note: It is strongly advised to set a
127+
reasonably strict timeout here in order not to create starving tasks which never finish
128+
in case something is wrong with the backend call.
129+
argument_type: The type of function argument used for batching. One of "list" or "numpy".
130+
By default "list" is used, i.e. it is assumed that the input arguments to the
131+
wrapped functions are lists which can be concatenated. If set to "numpy" the arguments
132+
are assumed to be numpy arrays which can be concatenated by numpy.concatenate()
133+
along axis 0.
123134
124135
Raises:
125136
ValueError: If the arguments target_batch_size or max_waiting_time are not >= 0.
137+
ValueError: When calling the function with incorrect or inconsistent arguments.
126138
asyncio.TimeoutError: Is raised when calling the wrapped function takes longer than
127139
max_processing_time
128140
@@ -137,7 +149,9 @@ def batch(*, target_batch_size: int, max_waiting_time: float, max_processing_tim
137149
function in order to avoid race conditions.
138150
- The return value of the wrapped function must be a single iterable.
139151
- All calls to the underlying function need to have the same number of positional arguments and
140-
the same keyword arguments.
152+
the same keyword arguments. It also isn't possible to mix the two ways to pass an argument.
153+
The same argument always has to be passed either as keyword argument or as positional
154+
argument.
141155
142156
Example:
143157
>>> import asyncio
@@ -197,16 +211,20 @@ def add_args_to_queue(args, kwargs):
197211
"""Add a new argument vector to the queue and return result indices"""
198212
nonlocal queue, n_rows_in_queue
199213

200-
queue.append((args, kwargs))
201-
offset = n_rows_in_queue
214+
if queue and (len(args) != len(queue[0][0]) or kwargs.keys() != queue[0][1].keys()):
215+
raise ValueError("Inconsistent use of positional and keyword arguments")
216+
n_rows_call = 0
202217
if args:
203-
n_rows_in_queue += len(args[0])
218+
n_rows_call = len(args[0])
204219
elif kwargs:
205220
for v in kwargs.values():
206-
n_rows_in_queue += len(v)
207-
break
208-
else:
221+
n_rows_call = len(v)
222+
break # We only need one arbitrary keyword argument
223+
if n_rows_call == 0:
209224
raise ValueError("Function called with empty collections as arguments")
225+
queue.append((args, kwargs))
226+
offset = n_rows_in_queue
227+
n_rows_in_queue += n_rows_call
210228
return offset, n_rows_in_queue
211229

212230
async def wait_for_calculation(batch_no_to_calculate):
@@ -237,7 +255,6 @@ async def calculate(batch_no_to_calculate):
237255
results_ready[batch_no_to_calculate] = n_results
238256
result_events[batch_no_to_calculate].set()
239257

240-
241258
def pop_args_from_queue():
242259
nonlocal batch_no, queue, n_rows_in_queue
243260

mkdocs.yml

+1
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ plugins:
3030
- mkapi:
3131
src_dirs:
3232
- .
33+
filters: [short, strict]

tests/test_batch.py

+62-22
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,100 @@
22
import asyncio
33
import random
44

5+
import numpy as np
56
import pytest
67

78
import dike
89

910

11+
def exceptions_equal(exception1, exception2):
    """Return True if both exceptions have the same type and message.

    Args:
        exception1: First exception instance to compare.
        exception2: Second exception instance to compare.

    Returns:
        True if both objects are of exactly the same class (a subclass does not
        count as a match) and their ``str()`` representations are equal,
        otherwise False.
    """
    # Classes are compared by identity: an exact type match is intended here,
    # and `is` is the idiomatic way to express that (pycodestyle E721).
    return type(exception1) is type(exception2) and str(exception1) == str(exception2)
14+
15+
1016
async def raise_error(message):
1117
raise RuntimeError(message)
1218

1319

14-
def test_single_items_batchsize_reached():
20+
@pytest.mark.parametrize("argument_type", [list, np.array])
21+
def test_single_items_batchsize_reached(argument_type):
1522
@dike.batch(target_batch_size=3, max_waiting_time=10)
1623
async def f(arg1, arg2):
17-
assert arg1 == [0, 1, 2]
18-
assert arg2 == ["a", "b", "c"]
19-
return [10, 11, 12]
24+
assert arg1 == argument_type([0, 1, 2])
25+
assert arg2 == argument_type(["a", "b", "c"])
26+
return argument_type([10, 11, 12])
2027

2128
async def run_test():
2229
result = await asyncio.wait_for(
2330
asyncio.gather(
24-
f([0], ["a"]),
25-
f([1], ["b"]),
26-
f([2], ["c"]),
31+
f(argument_type([0]), argument_type(["a"])),
32+
f(argument_type([1]), argument_type(["b"])),
33+
f(argument_type([2]), argument_type(["c"])),
2734
),
2835
timeout=1.0,
2936
)
3037

31-
assert result == [[10], [11], [12]]
38+
assert result == [argument_type([10]), argument_type([11]), argument_type([12])]
3239

3340
asyncio.run(run_test())
3441

3542

36-
def test_single_items_kwargs_batchsize_reached():
43+
@pytest.mark.parametrize("argument_type", [list, np.array])
44+
def test_single_items_kwargs_batchsize_reached(argument_type):
3745
@dike.batch(target_batch_size=3, max_waiting_time=10)
3846
async def f(arg1, arg2):
39-
assert arg1 == [0, 1, 2]
40-
assert arg2 == ["a", "b", "c"]
41-
return [10, 11, 12]
47+
assert arg1 == argument_type([0, 1, 2])
48+
assert arg2 == argument_type(["a", "b", "c"])
49+
return argument_type([10, 11, 12])
4250

4351
async def run_test():
4452
result = await asyncio.wait_for(
4553
asyncio.gather(
46-
f(arg1=[0], arg2=["a"]),
47-
f(arg1=[1], arg2=["b"]),
48-
f(arg2=["c"], arg1=[2]),
54+
f(arg2=argument_type(["a"]), arg1=argument_type([0])),
55+
f(arg2=argument_type(["b"]), arg1=argument_type([1])),
56+
f(arg1=argument_type([2]), arg2=argument_type(["c"])),
57+
# f(arg2=argument_type(["c"]), arg1=argument_type([2])),
4958
),
5059
timeout=1.0,
5160
)
5261

53-
assert result == [[10], [11], [12]]
62+
assert result == [argument_type([10]), argument_type([11]), argument_type([12])]
63+
64+
asyncio.run(run_test())
65+
66+
67+
@pytest.mark.parametrize("argument_type", [list, np.array])
68+
def test_single_items_mixed_kwargs_raises_value_error(argument_type):
69+
@dike.batch(target_batch_size=3, max_waiting_time=0.01)
70+
async def f(arg1, arg2):
71+
assert arg1 == argument_type([0, 1])
72+
assert arg2 == argument_type(["a", "b"])
73+
return argument_type([10, 11])
74+
75+
async def run_test():
76+
result = await asyncio.wait_for(
77+
asyncio.gather(
78+
f(argument_type([0]), argument_type(["a"])),
79+
f(argument_type([1]), argument_type(["b"])),
80+
f(arg2=argument_type(["c"]), arg1=argument_type([2])),
81+
f(argument_type([1])),
82+
f(argument_type([]), argument_type([])),
83+
return_exceptions=True
84+
),
85+
timeout=1.0,
86+
)
87+
88+
assert result[0] == argument_type([10])
89+
assert result[1] == argument_type([11])
90+
assert exceptions_equal(
91+
result[2], ValueError("Inconsistent use of positional and keyword arguments")
92+
)
93+
assert exceptions_equal(
94+
result[3], ValueError("Inconsistent use of positional and keyword arguments")
95+
)
96+
assert exceptions_equal(
97+
result[4], ValueError("Function called with empty collections as arguments")
98+
)
5499

55100
asyncio.run(run_test())
56101

@@ -193,12 +238,7 @@ async def f(arg1, arg2):
193238

194239
async def run_test():
195240
results = await asyncio.wait_for(
196-
asyncio.gather(
197-
f([0], ["a"]),
198-
f([1], ["b"]),
199-
f([2], ["c"]),
200-
return_exceptions=True
201-
),
241+
asyncio.gather(f([0], ["a"]), f([1], ["b"]), f([2], ["c"]), return_exceptions=True),
202242
timeout=1.0,
203243
)
204244
for r in results:

0 commit comments

Comments
 (0)