|
2 | 2 | import asyncio
|
3 | 3 | import random
|
4 | 4 |
|
| 5 | +import numpy as np |
5 | 6 | import pytest
|
6 | 7 |
|
7 | 8 | import dike
|
8 | 9 |
|
9 | 10 |
|
| 11 | +def exceptions_equal(exception1, exception2): |
| 12 | + """Returns True if the exceptions have the same type and message""" |
| 13 | + return type(exception1) == type(exception2) and str(exception1) == str(exception2) |
| 14 | + |
| 15 | + |
10 | 16 | async def raise_error(message):
|
11 | 17 | raise RuntimeError(message)
|
12 | 18 |
|
13 | 19 |
|
14 |
| -def test_single_items_batchsize_reached(): |
| 20 | +@pytest.mark.parametrize("argument_type", [list, np.array]) |
| 21 | +def test_single_items_batchsize_reached(argument_type): |
15 | 22 | @dike.batch(target_batch_size=3, max_waiting_time=10)
|
16 | 23 | async def f(arg1, arg2):
|
17 |
| - assert arg1 == [0, 1, 2] |
18 |
| - assert arg2 == ["a", "b", "c"] |
19 |
| - return [10, 11, 12] |
| 24 | + assert arg1 == argument_type([0, 1, 2]) |
| 25 | + assert arg2 == argument_type(["a", "b", "c"]) |
| 26 | + return argument_type([10, 11, 12]) |
20 | 27 |
|
21 | 28 | async def run_test():
|
22 | 29 | result = await asyncio.wait_for(
|
23 | 30 | asyncio.gather(
|
24 |
| - f([0], ["a"]), |
25 |
| - f([1], ["b"]), |
26 |
| - f([2], ["c"]), |
| 31 | + f(argument_type([0]), argument_type(["a"])), |
| 32 | + f(argument_type([1]), argument_type(["b"])), |
| 33 | + f(argument_type([2]), argument_type(["c"])), |
27 | 34 | ),
|
28 | 35 | timeout=1.0,
|
29 | 36 | )
|
30 | 37 |
|
31 |
| - assert result == [[10], [11], [12]] |
| 38 | + assert result == [argument_type([10]), argument_type([11]), argument_type([12])] |
32 | 39 |
|
33 | 40 | asyncio.run(run_test())
|
34 | 41 |
|
35 | 42 |
|
36 |
| -def test_single_items_kwargs_batchsize_reached(): |
| 43 | +@pytest.mark.parametrize("argument_type", [list, np.array]) |
| 44 | +def test_single_items_kwargs_batchsize_reached(argument_type): |
37 | 45 | @dike.batch(target_batch_size=3, max_waiting_time=10)
|
38 | 46 | async def f(arg1, arg2):
|
39 |
| - assert arg1 == [0, 1, 2] |
40 |
| - assert arg2 == ["a", "b", "c"] |
41 |
| - return [10, 11, 12] |
| 47 | + assert arg1 == argument_type([0, 1, 2]) |
| 48 | + assert arg2 == argument_type(["a", "b", "c"]) |
| 49 | + return argument_type([10, 11, 12]) |
42 | 50 |
|
43 | 51 | async def run_test():
|
44 | 52 | result = await asyncio.wait_for(
|
45 | 53 | asyncio.gather(
|
46 |
| - f(arg1=[0], arg2=["a"]), |
47 |
| - f(arg1=[1], arg2=["b"]), |
48 |
| - f(arg2=["c"], arg1=[2]), |
| 54 | + f(arg2=argument_type(["a"]), arg1=argument_type([0])), |
| 55 | + f(arg2=argument_type(["b"]), arg1=argument_type([1])), |
| 56 | + f(arg1=argument_type([2]), arg2=argument_type(["c"])), |
| 57 | + # f(arg2=argument_type(["c"]), arg1=argument_type([2])), |
49 | 58 | ),
|
50 | 59 | timeout=1.0,
|
51 | 60 | )
|
52 | 61 |
|
53 |
| - assert result == [[10], [11], [12]] |
| 62 | + assert result == [argument_type([10]), argument_type([11]), argument_type([12])] |
| 63 | + |
| 64 | + asyncio.run(run_test()) |
| 65 | + |
| 66 | + |
| 67 | +@pytest.mark.parametrize("argument_type", [list, np.array]) |
| 68 | +def test_single_items_mixed_kwargs_raises_value_error(argument_type): |
| 69 | + @dike.batch(target_batch_size=3, max_waiting_time=0.01) |
| 70 | + async def f(arg1, arg2): |
| 71 | + assert arg1 == argument_type([0, 1]) |
| 72 | + assert arg2 == argument_type(["a", "b"]) |
| 73 | + return argument_type([10, 11]) |
| 74 | + |
| 75 | + async def run_test(): |
| 76 | + result = await asyncio.wait_for( |
| 77 | + asyncio.gather( |
| 78 | + f(argument_type([0]), argument_type(["a"])), |
| 79 | + f(argument_type([1]), argument_type(["b"])), |
| 80 | + f(arg2=argument_type(["c"]), arg1=argument_type([2])), |
| 81 | + f(argument_type([1])), |
| 82 | + f(argument_type([]), argument_type([])), |
| 83 | + return_exceptions=True |
| 84 | + ), |
| 85 | + timeout=1.0, |
| 86 | + ) |
| 87 | + |
| 88 | + assert result[0] == argument_type([10]) |
| 89 | + assert result[1] == argument_type([11]) |
| 90 | + assert exceptions_equal( |
| 91 | + result[2], ValueError("Inconsistent use of positional and keyword arguments") |
| 92 | + ) |
| 93 | + assert exceptions_equal( |
| 94 | + result[3], ValueError("Inconsistent use of positional and keyword arguments") |
| 95 | + ) |
| 96 | + assert exceptions_equal( |
| 97 | + result[4], ValueError("Function called with empty collections as arguments") |
| 98 | + ) |
54 | 99 |
|
55 | 100 | asyncio.run(run_test())
|
56 | 101 |
|
@@ -193,12 +238,7 @@ async def f(arg1, arg2):
|
193 | 238 |
|
194 | 239 | async def run_test():
|
195 | 240 | results = await asyncio.wait_for(
|
196 |
| - asyncio.gather( |
197 |
| - f([0], ["a"]), |
198 |
| - f([1], ["b"]), |
199 |
| - f([2], ["c"]), |
200 |
| - return_exceptions=True |
201 |
| - ), |
| 241 | + asyncio.gather(f([0], ["a"]), f([1], ["b"]), f([2], ["c"]), return_exceptions=True), |
202 | 242 | timeout=1.0,
|
203 | 243 | )
|
204 | 244 | for r in results:
|
|
0 commit comments