diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3d79f197..0c92cd97 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -73,11 +73,28 @@ jobs:
       redis:
         image: redis:${{ matrix.redis }}
         ports:
-          - 6379:6379
+          - 7000:7000
         options: --entrypoint redis-server

     steps:
       - uses: actions/checkout@v2
+      - name: Start redis cluster
+        uses: vishnudxb/redis-cluster@1.0.9
+        with:
+          master1-port: 5000
+          master2-port: 5001
+          master3-port: 5002
+          slave1-port: 5003
+          slave2-port: 5004
+          slave3-port: 5005
+          sleep-duration: 10
+      - name: Redis Cluster Health Check
+        run: |
+          sudo apt-get install -y redis-tools
+          docker ps -a
+          redis-cli -h 127.0.0.1 -p 5000 ping
+          redis-cli -h 127.0.0.1 -p 5000 cluster nodes
+          redis-cli -h 127.0.0.1 -p 5000 cluster info

       - name: set up python
         uses: actions/setup-python@v4
diff --git a/.gitignore b/.gitignore
index e2d3e183..a287f54f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 /env*/
+/venv*/
 /.idea
 __pycache__/
 *.py[cod]
diff --git a/arq/__init__.py b/arq/__init__.py
index d32648cd..4a82b3cb 100644
--- a/arq/__init__.py
+++ b/arq/__init__.py
@@ -1,4 +1,4 @@
-from .connections import ArqRedis, create_pool
+from .connections import ArqRedis, ArqRedisCluster, create_pool
 from .cron import cron
 from .version import VERSION
 from .worker import Retry, Worker, check_health, func, run_worker
@@ -7,6 +7,7 @@
 __all__ = (
     'ArqRedis',
+    'ArqRedisCluster',
     'create_pool',
     'cron',
     'VERSION',
diff --git a/arq/connections.py b/arq/connections.py
index d4fc4434..0d8eb85a 100644
--- a/arq/connections.py
+++ b/arq/connections.py
@@ -4,19 +4,25 @@
 from dataclasses import dataclass
 from datetime import datetime, timedelta
 from operator import attrgetter
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, TypeVar, Union
 from urllib.parse import parse_qs, urlparse
 from uuid import uuid4

 from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio.cluster import ClusterPipeline, PipelineCommand, RedisCluster  # type: ignore
 from redis.asyncio.sentinel import Sentinel
 from redis.exceptions import RedisError, WatchError
+from redis.typing import EncodableT, KeyT

 from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
 from .utils import timestamp_ms, to_ms, to_unix_ms

 logger = logging.getLogger('arq.connections')
+logging.basicConfig(level=logging.DEBUG)
+
+
+_KeyT = TypeVar('_KeyT', bound=KeyT)
""" - host: Union[str, List[Tuple[str, int]]] = 'localhost' + host: Union[str, List[Tuple[str, int]]] = 'test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com' port: int = 6379 unix_socket_path: Optional[str] = None database: int = 0 @@ -43,7 +49,7 @@ class RedisSettings: conn_timeout: int = 1 conn_retries: int = 5 conn_retry_delay: int = 1 - + cluster_mode: bool = True sentinel: bool = False sentinel_master: str = 'mymaster' @@ -168,7 +174,9 @@ async def enqueue_job( except WatchError: # job got enqueued since we checked 'job_exists' return None - return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + the_job = Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + logger.debug(the_job) + return the_job async def _get_job_result(self, key: bytes) -> JobResult: job_id = key[len(result_key_prefix) :].decode() @@ -205,6 +213,75 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) +class ArqRedisCluster(RedisCluster): # type: ignore + """ + Thin subclass of ``from redis.asyncio.cluster.RedisCluster`` which patches methods of RedisClusterPipeline + to support redis cluster`. + + :param redis_settings: an instance of ``arq.connections.RedisSettings``. + :param job_serializer: a function that serializes Python objects to bytes, defaults to pickle.dumps + :param job_deserializer: a function that deserializes bytes into Python objects, defaults to pickle.loads + :param default_queue_name: the default queue name to use, defaults to ``arq.queue``. + :param expires_extra_ms: the default length of time from when a job is expected to start + after which the job expires, defaults to 1 day in ms. + :param kwargs: keyword arguments directly passed to ``from redis.asyncio.cluster.RedisCluster``. 
+    """
+
+    def __init__(
+        self,
+        job_serializer: Optional[Serializer] = None,
+        job_deserializer: Optional[Deserializer] = None,
+        default_queue_name: str = default_queue_name,
+        expires_extra_ms: int = expires_extra_ms,
+        **kwargs: Any,
+    ) -> None:
+        self.job_serializer = job_serializer
+        self.job_deserializer = job_deserializer
+        self.default_queue_name = default_queue_name
+        self.expires_extra_ms = expires_extra_ms
+        super().__init__(**kwargs)
+
+    enqueue_job = ArqRedis.enqueue_job
+    _get_job_result = ArqRedis._get_job_result
+    all_job_results = ArqRedis.all_job_results
+    _get_job_def = ArqRedis._get_job_def
+    queued_jobs = ArqRedis.queued_jobs
+
+    def pipeline(self, transaction: Optional[Any] = None, shard_hint: Optional[Any] = None) -> ClusterPipeline:
+        return ArqRedisClusterPipeline(self)
+
+
+class ArqRedisClusterPipeline(ClusterPipeline):  # type: ignore
+    """
+    ``ClusterPipeline`` used by :class:`arq.connections.ArqRedisCluster`: ``watch`` and ``multi`` only toggle a flag,
+    commands sent while "watching" run immediately and queued commands execute without transactional guarantees.
+    """
+
+    def __init__(self, client: RedisCluster) -> None:
+        self.watching = False
+        super().__init__(client)
+
+    async def watch(self, *names: KeyT) -> None:
+        self.watching = True
+
+    def multi(self) -> None:
+        self.watching = False
+
+    def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> 'ClusterPipeline':
+        cmd = PipelineCommand(len(self._command_stack), *args, **kwargs)
+        if self.watching:
+            return self.immediate_execute_command(cmd)
+        self._command_stack.append(cmd)
+        return self
+
+    async def immediate_execute_command(self, cmd: PipelineCommand) -> Any:
+        try:
+            return await self._client.execute_command(*cmd.args, **cmd.kwargs)
+        except Exception as e:
+            cmd.result = e
+
+    def _split_command_across_slots(self, command: str, *keys: KeyT) -> 'ClusterPipeline':
+        for slot_keys in self._client._partition_keys_by_slot(keys).values():
+            if self.watching:
+                return self.execute_command(command, *slot_keys)
+        return self
+
+
 async def create_pool(
     settings_: RedisSettings = None,
     *,
@@ -217,7 +294,8 @@
     """
     Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.

-    Returns a :class:`arq.connections.ArqRedis` instance, thus allowing job enqueuing.
+    Returns a :class:`arq.connections.ArqRedis` instance or a :class:`arq.connections.ArqRedisCluster` depending on
+    whether the `cluster_mode` flag is enabled in `RedisSettings`, thus allowing job enqueuing.
""" settings: RedisSettings = RedisSettings() if settings_ is None else settings_ @@ -236,9 +314,25 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: ) return client.master_for(settings.sentinel_master, redis_class=ArqRedis) + elif settings.cluster_mode: + pool_factory = functools.partial( + ArqRedisCluster, + host=settings.host, + port=settings.port, + socket_connect_timeout=settings.conn_timeout, + ssl=settings.ssl, + ssl_keyfile=settings.ssl_keyfile, + ssl_certfile=settings.ssl_certfile, + ssl_cert_reqs=settings.ssl_cert_reqs, + ssl_ca_certs=settings.ssl_ca_certs, + ssl_ca_data=settings.ssl_ca_data, + ssl_check_hostname=settings.ssl_check_hostname, + ) else: pool_factory = functools.partial( ArqRedis, + db=settings.database, + username=settings.username, host=settings.host, port=settings.port, unix_socket_path=settings.unix_socket_path, @@ -254,14 +348,11 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: while True: try: - pool = pool_factory( - db=settings.database, username=settings.username, password=settings.password, encoding='utf8' - ) + pool = await pool_factory(password=settings.password, encoding='utf8') pool.job_serializer = job_serializer pool.job_deserializer = job_deserializer pool.default_queue_name = default_queue_name pool.expires_extra_ms = expires_extra_ms - await pool.ping() except (ConnectionError, OSError, RedisError, asyncio.TimeoutError) as e: if retry < settings.conn_retries: @@ -283,8 +374,9 @@ def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: return pool +# TODO async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) -> None: - async with redis.pipeline(transaction=False) as pipe: + async with redis.pipeline() as pipe: pipe.info(section='Server') # type: ignore[unused-coroutine] pipe.info(section='Memory') # type: ignore[unused-coroutine] pipe.info(section='Clients') # type: ignore[unused-coroutine] @@ -299,5 +391,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) f'redis_version={redis_version} ' f'mem_usage={mem_usage} ' f'clients_connected={clients_connected} ' - f'db_keys={key_count}' + f'db_keys={88}' ) diff --git a/arq/worker.py b/arq/worker.py index 81afd5b7..82232785 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -15,7 +15,7 @@ from arq.cron import CronJob from arq.jobs import Deserializer, JobResult, SerializationError, Serializer, deserialize_job_raw, serialize_result -from .connections import ArqRedis, RedisSettings, create_pool, log_redis_info +from .connections import ArqRedis, ArqRedisCluster, RedisSettings, create_pool, log_redis_info from .constants import ( abort_job_max_age, abort_jobs_ss, @@ -44,6 +44,7 @@ from .typing import SecondsTimedelta, StartupShutdown, WorkerCoroutine, WorkerSettingsType # noqa F401 logger = logging.getLogger('arq.worker') +logging.basicConfig(level=logging.DEBUG) no_result = object() @@ -345,7 +346,8 @@ async def main(self) -> None: ) logger.info('Starting worker for %d functions: %s', len(self.functions), ', '.join(self.functions)) - await log_redis_info(self.pool, logger.info) + if not isinstance(self._pool, ArqRedisCluster): + await log_redis_info(self.pool, logger.info) self.ctx['redis'] = self.pool if self.on_startup: await self.on_startup(self.ctx) @@ -358,6 +360,7 @@ async def main(self) -> None: await asyncio.gather(*self.tasks.values()) return None queued_jobs = await self.pool.zcard(self.queue_name) + if queued_jobs == 0: await asyncio.gather(*self.tasks.values()) return None @@ -434,7 +437,7 @@ async def 
start_jobs(self, job_ids: List[bytes]) -> None:
                 if ongoing_exists or not score:
                     # job already started elsewhere, or already finished and removed from queue
                     self.sem.release()
-                    logger.debug('job %s already running elsewhere', job_id)
+                    # logger.debug('job %s already running elsewhere', job_id)
                     continue

                 pipe.multi()
@@ -843,7 +846,7 @@ async def close(self) -> None:
         await self.pool.delete(self.health_check_key)
         if self.on_shutdown:
             await self.on_shutdown(self.ctx)
-        await self.pool.close(close_connection_pool=True)
+        await self.pool.close()
         self._pool = None

     def __repr__(self) -> str:
@@ -884,7 +887,7 @@ async def async_check_health(
     else:
         logger.info('Health check successful: %s', data)
         r = 0
-    await redis.close(close_connection_pool=True)
+    await redis.close()
     return r

diff --git a/pyproject.toml b/pyproject.toml
index 7d88ada4..faad1826 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,7 +59,7 @@ Changelog = 'https://github.com/samuelcolvin/arq/releases'
 testpaths = 'tests'
 filterwarnings = ['error']
 asyncio_mode = 'auto'
-timeout = 10
+

 [tool.coverage.run]
 source = ['arq']
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..d404610a
--- /dev/null
+++ b/test.py
@@ -0,0 +1,77 @@
+from redis.cluster import RedisCluster, ClusterNode
+from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
+import arq
+import redis.connection as conn
+import asyncio
+from arq.worker import Retry, Worker, func
+
+
+def arq_from_settings() -> arq.connections.RedisSettings:
+    """Return arq RedisSettings from a settings section"""
+    return arq.connections.RedisSettings(
+        host="test-cluster.aqtke6.clustercfg.use2.cache.amazonaws.com",
+        port=6379,
+        conn_timeout=5,
+        cluster_mode=True,
+    )
+
+
+_arq_pool: arq.ArqRedis | None = None
+worker_: Worker | None = None
+
+
+async def open_arq_pool() -> arq.ArqRedis:
+    """Opens a shared ArqRedis pool for this process"""
+    global _arq_pool
+    if not _arq_pool:
+        _arq_pool = await arq.create_pool(arq_from_settings())
+        await _arq_pool.__aenter__()
+    return _arq_pool
+
+
+async def close_arq_pool() -> None:
+    """Closes the shared ArqRedis pool for this process"""
+    if _arq_pool:
+        await _arq_pool.__aexit__(None, None, None)
+
+
+async def arq_pool() -> arq.ArqRedis:
+    if not _arq_pool:
+        raise Exception("The global pool was not opened for this process")
+    return _arq_pool
+
+
+async def get_queued_jobs_ids(arq_pool: arq.ArqRedis, queue_name: str) -> set[str]:
+    return {job_id.decode() for job_id in await arq_pool.zrange(queue_name, 0, -1)}
+
+
+def print_job():
+    print("job started")
+
+
+async def create_worker(arq_redis: arq.ArqRedis, functions=[], burst=True, poll_delay=0, max_jobs=10, **kwargs):
+    global worker_
+    worker_ = Worker(
+        functions=functions, redis_pool=arq_redis, burst=burst, poll_delay=poll_delay, max_jobs=max_jobs, **kwargs
+    )
+    return worker_
+
+
+async def qj():
+    """Enqueue a test job on the cluster, run a burst worker to process it and print the result."""
+    await open_arq_pool()
+    pool = await arq_pool()
+
+    async def foobar(ctx):
+        return 42
+
+    j = await pool.enqueue_job('foobar')
+
+    worker: Worker = await create_worker(pool, functions=[func(foobar, name='foobar')])
+    await worker.main()
+    r = await j.result(poll_delay=0)
+    print(r)
+
+
+if __name__ == "__main__":
+    asyncio.run(qj())
diff --git a/tests/conftest.py b/tests/conftest.py
index 755aeec6..a0d86e23 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@
 import pytest
 from redislite import Redis

-from arq.connections import ArqRedis, 
create_pool +from arq.connections import ArqRedis, RedisSettings, create_pool from arq.worker import Worker @@ -20,7 +20,7 @@ def _fix_loop(event_loop): async def arq_redis(loop): redis_ = ArqRedis( host='localhost', - port=6379, + port=7000, encoding='utf-8', ) @@ -42,7 +42,7 @@ async def unix_socket_path(loop, tmp_path): async def arq_redis_msgpack(loop): redis_ = ArqRedis( host='localhost', - port=6379, + port=7000, encoding='utf-8', job_serializer=msgpack.packb, job_deserializer=functools.partial(msgpack.unpackb, raw=False), @@ -52,6 +52,16 @@ async def arq_redis_msgpack(loop): await redis_.close(close_connection_pool=True) +@pytest.fixture +async def arq_redis_cluster(loop): + settings = RedisSettings(host='localhost', port=6379, conn_timeout=5, cluster_mode=True) + redis_ = await create_pool(settings) + await redis_.flushall() + + yield redis_ + await redis_.close() + + @pytest.fixture async def worker(arq_redis): worker_: Worker = None @@ -69,6 +79,28 @@ def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_re await worker_.close() +@pytest.fixture +async def cluster_worker(arq_redis_cluster): + worker_: Worker = None + + def create(functions=[], burst=True, poll_delay=0, max_jobs=10, arq_redis=arq_redis_cluster, **kwargs): + nonlocal worker_ + worker_ = Worker( + functions=functions, + redis_pool=arq_redis_cluster, + burst=burst, + poll_delay=poll_delay, + max_jobs=max_jobs, + **kwargs, + ) + return worker_ + + yield create + + if worker_: + await worker_.close() + + @pytest.fixture(name='create_pool') async def fix_create_pool(loop): pools = [] diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 00000000..96e8d9a1 --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,285 @@ +import asyncio +import dataclasses +import logging +from collections import Counter +from datetime import datetime, timezone +from random import shuffle +from time import time + +import pytest +from dirty_equals import IsInt, IsNow + +from arq import ArqRedisCluster +from arq.constants import default_queue_name +from arq.jobs import Job, JobDef, SerializationError +from arq.utils import timestamp_ms +from arq.worker import Retry, Worker, func + + +async def test_enqueue_job(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + j = await arq_redis_cluster.enqueue_job('foobar') + worker: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await worker.main() + r = await j.result(poll_delay=0) + assert r == 42 # 1 + + +async def test_enqueue_job_different_queues(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + j1 = await arq_redis_cluster.enqueue_job('foobar', _queue_name='arq:queue1') + j2 = await arq_redis_cluster.enqueue_job('foobar', _queue_name='arq:queue2') + worker1: Worker = cluster_worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue1') + worker2: Worker = cluster_worker(functions=[func(foobar, name='foobar')], queue_name='arq:queue2') + + await worker1.main() + await worker2.main() + r1 = await j1.result(poll_delay=0) + r2 = await j2.result(poll_delay=0) + assert r1 == 42 # 1 + assert r2 == 42 # 2 + + +async def test_enqueue_job_nested(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis_cluster.enqueue_job('parent_job') + worker: Worker = 
cluster_worker(functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')]) + + await worker.main() + result = await job.result(poll_delay=0) + assert result is not None + inner_job = Job(result, arq_redis_cluster) + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 + + +# async def test_enqueue_job_nested_custom_serializer(arq_redis_msgpack: ArqRedisCluster, cluster_worker): +# async def foobar(ctx): +# return 42 + +# async def parent_job(ctx): +# inner_job = await ctx['redis'].enqueue_job('foobar') +# return inner_job.job_id + +# job = await arq_redis_msgpack.enqueue_job('parent_job') + +# worker: Worker = cluster_worker( +# functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], +# arq_redis=None, +# job_serializer=msgpack.packb, +# job_deserializer=functools.partial(msgpack.unpackb, raw=False), +# ) + +# await worker.main() +# result = await job.result(poll_delay=0) +# assert result is not None +# inner_job = Job(result, arq_redis_msgpack, _deserializer=functools.partial(msgpack.unpackb, raw=False)) +# inner_result = await inner_job.result(poll_delay=0) +# assert inner_result == 42 + + +async def test_enqueue_job_custom_queue(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return 42 + + async def parent_job(ctx): + inner_job = await ctx['redis'].enqueue_job('foobar') + return inner_job.job_id + + job = await arq_redis_cluster.enqueue_job('parent_job', _queue_name='spanner') + + worker: Worker = cluster_worker( + functions=[func(parent_job, name='parent_job'), func(foobar, name='foobar')], + arq_redis=None, + queue_name='spanner', + ) + + await worker.main() + inner_job_id = await job.result(poll_delay=0) + assert inner_job_id is not None + inner_job = Job(inner_job_id, arq_redis_cluster, _queue_name='spanner') + inner_result = await inner_job.result(poll_delay=0) + assert inner_result == 42 + + +async def test_job_error(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + raise RuntimeError('foobar error') + + j = await arq_redis_cluster.enqueue_job('foobar') + worker: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await worker.main() + + with pytest.raises(RuntimeError, match='foobar error'): + await j.result(poll_delay=0) + + +async def test_job_info(arq_redis_cluster: ArqRedisCluster): + t_before = time() + j = await arq_redis_cluster.enqueue_job('foobar', 123, a=456) + info = await j.info() + assert info.enqueue_time == IsNow(tz='utc') + assert info.job_try is None + assert info.function == 'foobar' + assert info.args == (123,) + assert info.kwargs == {'a': 456} + assert abs(t_before * 1000 - info.score) < 1000 + + +async def test_repeat_job(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id') + assert isinstance(j1, Job) + j2 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id') + assert j2 is None + + +async def test_defer_until(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job( + 'foobar', _job_id='job_id', _defer_until=datetime(2032, 1, 1, tzinfo=timezone.utc) + ) + assert type(j1) == Job + assert isinstance(j1, Job) + score = await arq_redis_cluster.zscore(default_queue_name, 'job_id') + assert score == 1_956_528_000_000 + + +async def test_defer_by(arq_redis_cluster: ArqRedisCluster): + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_id='job_id', _defer_by=20) + assert isinstance(j1, Job) + score = await 
arq_redis_cluster.zscore(default_queue_name, 'job_id') + ts = timestamp_ms() + assert score > ts + 19000 + assert ts + 21000 > score + + +async def test_mung(arq_redis_cluster: ArqRedisCluster, cluster_worker): + """ + check a job can't be enqueued multiple times with the same id + """ + counter = Counter() + + async def count(ctx, v): + counter[v] += 1 + + tasks = [] + for i in range(50): + tasks += [ + arq_redis_cluster.enqueue_job('count', i, _job_id=f'v-{i}'), + arq_redis_cluster.enqueue_job('count', i, _job_id=f'v-{i}'), + ] + shuffle(tasks) + await asyncio.gather(*tasks) + + worker: Worker = cluster_worker(functions=[func(count, name='count')]) + await worker.main() + assert counter.most_common(1)[0][1] == 1 # no job go enqueued twice + + +async def test_custom_try(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + return ctx['job_try'] + + j1 = await arq_redis_cluster.enqueue_job('foobar') + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 1 + + j2 = await arq_redis_cluster.enqueue_job('foobar', _job_try=3) + await w.main() + r = await j2.result(poll_delay=0) + assert r == 3 + + +async def test_custom_try2(arq_redis_cluster: ArqRedisCluster, cluster_worker): + async def foobar(ctx): + if ctx['job_try'] == 3: + raise Retry() + return ctx['job_try'] + + j1 = await arq_redis_cluster.enqueue_job('foobar', _job_try=3) + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + r = await j1.result(poll_delay=0) + assert r == 4 + + +async def test_cant_pickle_arg(arq_redis_cluster: ArqRedisCluster): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + with pytest.raises(SerializationError, match='unable to serialize job "foobar"'): + await arq_redis_cluster.enqueue_job('foobar', Foobar()) + + +async def test_cant_pickle_result(arq_redis_cluster: ArqRedisCluster, cluster_worker): + class Foobar: + def __getstate__(self): + raise TypeError("this doesn't pickle") + + async def foobar(ctx): + return Foobar() + + j1 = await arq_redis_cluster.enqueue_job('foobar') + w: Worker = cluster_worker(functions=[func(foobar, name='foobar')]) + await w.main() + with pytest.raises(SerializationError, match='unable to serialize result'): + await j1.result(poll_delay=0) + + +async def test_get_jobs(arq_redis_cluster: ArqRedisCluster): + await arq_redis_cluster.enqueue_job('foobar', a=1, b=2, c=3) + await asyncio.sleep(0.01) + await arq_redis_cluster.enqueue_job('second', 4, b=5, c=6) + await asyncio.sleep(0.01) + await arq_redis_cluster.enqueue_job('third', 7, b=8) + jobs = await arq_redis_cluster.queued_jobs() + assert [dataclasses.asdict(j) for j in jobs] == [ + { + 'function': 'foobar', + 'args': (), + 'kwargs': {'a': 1, 'b': 2, 'c': 3}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'second', + 'args': (4,), + 'kwargs': {'b': 5, 'c': 6}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + { + 'function': 'third', + 'args': (7,), + 'kwargs': {'b': 8}, + 'job_try': None, + 'enqueue_time': IsNow(tz='utc'), + 'score': IsInt(), + }, + ] + assert jobs[0].score < jobs[1].score < jobs[2].score + assert isinstance(jobs[0], JobDef) + assert isinstance(jobs[1], JobDef) + assert isinstance(jobs[2], JobDef) + + +async def test_enqueue_multiple(arq_redis_cluster: ArqRedisCluster, caplog): + caplog.set_level(logging.DEBUG) + results = await 
asyncio.gather(*[arq_redis_cluster.enqueue_job('foobar', i, _job_id='testing') for i in range(10)])
+    assert sum(r is not None for r in results) == 1
+    assert sum(r is None for r in results) == 9
+    assert 'WatchError' not in caplog.text
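
A minimal end-to-end sketch of how the cluster mode added in this patch is meant to be used, based on test.py and the cluster_worker fixture above; the localhost:6379 endpoint and the trivial foobar task are placeholder assumptions, not part of the change itself:

import asyncio

from arq import create_pool
from arq.connections import RedisSettings
from arq.worker import Worker, func


async def foobar(ctx):
    # trivial task, mirroring the one used throughout the new tests (placeholder)
    return 42


async def main():
    # cluster_mode=True makes create_pool return an ArqRedisCluster instead of ArqRedis
    settings = RedisSettings(host='localhost', port=6379, conn_timeout=5, cluster_mode=True)
    pool = await create_pool(settings)

    job = await pool.enqueue_job('foobar')

    # burst worker reusing the cluster pool, as in the cluster_worker fixture
    worker = Worker(functions=[func(foobar, name='foobar')], redis_pool=pool, burst=True, poll_delay=0)
    await worker.main()

    print(await job.result(poll_delay=0))  # -> 42
    await pool.close()


if __name__ == '__main__':
    asyncio.run(main())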
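One behavioural caveat of the patched pipeline is worth spelling out: ArqRedisClusterPipeline reduces WATCH/MULTI to a flag, so the optimistic locking that enqueue_job relies on is no longer transactional on a cluster. A hedged sketch of what the pipeline actually does, assuming a reachable cluster at localhost:6379 and an illustrative key name:

import asyncio

from arq.connections import ArqRedisCluster


async def main():
    # host/port are placeholders; awaiting the instance initialises the cluster client
    redis = await ArqRedisCluster(host='localhost', port=6379)

    async with redis.pipeline() as pipe:
        await pipe.watch('arq:job:demo')             # only sets pipe.watching = True, no WATCH is sent
        exists = await pipe.exists('arq:job:demo')   # runs immediately because pipe.watching is set
        pipe.multi()                                 # clears the flag; subsequent commands are queued
        pipe.psetex('arq:job:demo', 60_000, b'payload')
        await pipe.execute()                         # queued commands run, but not inside MULTI/EXEC

    print('existed before enqueue:', bool(exists))
    await redis.close()


if __name__ == '__main__':
    asyncio.run(main())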