Change RedisBackend to accept Redis client directly #755

Merged
1 change: 1 addition & 0 deletions CHANGES.rst
@@ -12,6 +12,7 @@ Migration instructions

There are a number of backwards-incompatible changes. These points should help with migrating from an older release:

* ``RedisBackend`` now expects a ``redis.Redis`` instance as an argument, instead of creating one internally from keyword arguments (a migration sketch follows this list).
* The ``key_builder`` parameter for caches now expects a callback which accepts 2 strings and returns a string in all cache implementations, making the builders simpler and interchangeable.
* The ``key`` parameter has been removed from the ``cached`` decorator. The behaviour can be easily reimplemented with ``key_builder=lambda *a, **kw: "foo"``.
* When using the ``key_builder`` parameter in ``@multicached``, the function will now return the original, unmodified keys, only using the transformed keys in the cache (this has always been the documented behaviour, but not the implemented behaviour).
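
For the first point above, a minimal before/after migration sketch (assuming a local Redis on the default port; the host, port and db values are illustrative):

import redis.asyncio as redis

from aiocache import RedisCache

# Before (older releases): connection details were keyword arguments.
# cache = RedisCache(endpoint="127.0.0.1", port=6379, db=0, password=None)

# After: build the client yourself and pass it in. Keep decode_responses=False;
# some values (e.g. pickle-serialized objects) are stored as raw bytes and
# would fail to decode at the connection level.
client = redis.Redis(host="127.0.0.1", port=6379, db=0, decode_responses=False)
cache = RedisCache(client=client)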
8 changes: 4 additions & 4 deletions aiocache/backends/memcached.py
@@ -8,13 +8,13 @@


class MemcachedBackend(BaseCache[bytes]):
def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs):
def __init__(self, host="127.0.0.1", port=11211, pool_size=2, **kwargs):
super().__init__(**kwargs)
self.endpoint = endpoint
self.host = host
self.port = port
self.pool_size = int(pool_size)
self.client = aiomcache.Client(
self.endpoint, self.port, pool_size=self.pool_size
self.host, self.port, pool_size=self.pool_size
)

async def _get(self, key, encoding="utf-8", _conn=None):
@@ -153,4 +153,4 @@ def parse_uri_path(cls, path):
return {}

def __repr__(self): # pragma: no cover
return "MemcachedCache ({}:{})".format(self.endpoint, self.port)
return "MemcachedCache ({}:{})".format(self.host, self.port)
46 changes: 9 additions & 37 deletions aiocache/backends/redis.py
@@ -1,5 +1,4 @@
import itertools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING

import redis.asyncio as redis
@@ -38,41 +37,19 @@ class RedisBackend(BaseCache[str]):

def __init__(
self,
endpoint="127.0.0.1",
port=6379,
db=0,
password=None,
pool_min_size=_NOT_SET,
pool_max_size=None,
create_connection_timeout=None,
client: redis.Redis,
**kwargs,
):
super().__init__(**kwargs)
if pool_min_size is not _NOT_SET:
warnings.warn(
"Parameter 'pool_min_size' is deprecated since aiocache 0.12",
DeprecationWarning, stacklevel=2
)

self.endpoint = endpoint
self.port = int(port)
self.db = int(db)
self.password = password
# TODO: Remove int() call some time after adding type annotations.
self.pool_max_size = None if pool_max_size is None else int(pool_max_size)
self.create_connection_timeout = (
float(create_connection_timeout) if create_connection_timeout else None
)

# NOTE: decoding can't be controlled on API level after switching to
# redis, we need to disable decoding on global/connection level
# (decode_responses=False), because some of the values are saved as
# bytes directly, like pickle serialized values, which may raise an
# exception when decoded with 'utf-8'.
self.client = redis.Redis(host=self.endpoint, port=self.port, db=self.db,
password=self.password, decode_responses=False,
socket_connect_timeout=self.create_connection_timeout,
max_connections=self.pool_max_size)
if client.connection_pool.connection_kwargs['decode_responses']:
raise ValueError("redis client must be constructed with decode_responses set to False")
self.client = client

async def _get(self, key, encoding="utf-8", _conn=None):
value = await self.client.get(key)
@@ -175,9 +152,6 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
async def _redlock_release(self, key, value):
return await self._raw("eval", self.RELEASE_SCRIPT, 1, key, value)

async def _close(self, *args, _conn=None, **kwargs):
await self.client.close()

def build_key(self, key: str, namespace: Optional[str] = None) -> str:
return self._str_build_key(key, namespace)

@@ -196,24 +170,21 @@ class RedisCache(RedisBackend):
the backend. Default is an empty string, "".
:param timeout: int or float in seconds specifying maximum timeout for the operations to last.
By default it is 5.
:param endpoint: str with the endpoint to connect to. Default is "127.0.0.1".
:param port: int with the port to connect to. Default is 6379.
:param db: int indicating database to use. Default is 0.
:param password: str indicating password to use. Default is None.
:param pool_max_size: int maximum pool size for the redis connections pool. Default is None.
:param create_connection_timeout: int timeout for the creation of connection. Default is None
:param client: redis.Redis instance, an active client for working with Redis
"""

NAME = "redis"

def __init__(
self,
client: redis.Redis,
serializer: Optional["BaseSerializer"] = None,
namespace: str = "",
key_builder: Callable[[str, str], str] = lambda k, ns: f"{ns}:{k}" if ns else k,
**kwargs: Any,
):
super().__init__(
client=client,
serializer=serializer or JsonSerializer(),
namespace=namespace,
key_builder=key_builder,
@@ -237,4 +208,5 @@ def parse_uri_path(cls, path):
return options

def __repr__(self): # pragma: no cover
return "RedisCache ({}:{})".format(self.endpoint, self.port)
connection_kwargs = self.client.connection_pool.connection_kwargs
return "RedisCache ({}:{})".format(connection_kwargs['host'], connection_kwargs['port'])
30 changes: 24 additions & 6 deletions aiocache/factory.py
@@ -1,12 +1,16 @@
import logging
import urllib
from contextlib import suppress
from copy import deepcopy
from typing import Dict

from aiocache import AIOCACHE_CACHES
from aiocache.base import BaseCache
from aiocache.exceptions import InvalidCacheType

with suppress(ImportError):
import redis.asyncio as redis


logger = logging.getLogger(__name__)

@@ -18,6 +22,7 @@ def _class_from_string(class_path):


def _create_cache(cache, serializer=None, plugins=None, **kwargs):
kwargs = deepcopy(kwargs)
if serializer is not None:
cls = serializer.pop("class")
cls = _class_from_string(cls) if isinstance(cls, str) else cls
@@ -29,10 +34,17 @@ def _create_cache(cache, serializer=None, plugins=None, **kwargs):
cls = plugin.pop("class")
cls = _class_from_string(cls) if isinstance(cls, str) else cls
plugins_instances.append(cls(**plugin))

cache = _class_from_string(cache) if isinstance(cache, str) else cache
instance = cache(serializer=serializer, plugins=plugins_instances, **kwargs)
return instance
if cache == AIOCACHE_CACHES.get("redis"):
    return cache(
        serializer=serializer,
        plugins=plugins_instances,
        namespace=kwargs.pop('namespace', ''),
        ttl=kwargs.pop('ttl', None),
        client=redis.Redis(**kwargs)
    )
else:
    return cache(serializer=serializer, plugins=plugins_instances, **kwargs)

Check failure (Code scanning / CodeQL): Modification of parameter with default ("This expression mutates a default value"), reported on both kwargs.pop lines above.


class Cache:
@@ -112,15 +124,21 @@ def from_url(cls, url):
kwargs.update(cache_class.parse_uri_path(parsed_url.path))

if parsed_url.hostname:
kwargs["endpoint"] = parsed_url.hostname
kwargs["host"] = parsed_url.hostname

if parsed_url.port:
kwargs["port"] = parsed_url.port

if parsed_url.password:
kwargs["password"] = parsed_url.password

return Cache(cache_class, **kwargs)
for arg in ['max_connections', 'socket_connect_timeout']:
if arg in kwargs:
kwargs[arg] = int(kwargs[arg])
if cache_class == cls.REDIS:
return Cache(cache_class, client=redis.Redis(**kwargs))
else:
return Cache(cache_class, **kwargs)
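
With this change, from_url builds the redis.Redis client itself from the parsed URL, coercing the numeric pool arguments to int first. A rough sketch of the resulting behaviour (illustrative and untested; the URL and its query values are assumptions):

from aiocache import Cache

cache = Cache.from_url(
    "redis://127.0.0.1:6379?socket_connect_timeout=5&max_connections=10"
)
# The query-string options end up on the underlying client and its pool.
kwargs = cache.client.connection_pool.connection_kwargs
assert kwargs["host"] == "127.0.0.1"
assert kwargs["socket_connect_timeout"] == 5
assert cache.client.connection_pool.max_connections == 10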


class CacheHandler:
@@ -214,7 +232,7 @@ def set_config(self, config):
},
'redis_alt': {
'cache': "aiocache.RedisCache",
'endpoint': "127.0.0.10",
'host': "127.0.0.10",
'port': 6378,
'serializer': {
'class': "aiocache.serializers.PickleSerializer"
15 changes: 9 additions & 6 deletions examples/cached_alias_config.py
@@ -1,5 +1,7 @@
import asyncio

import redis.asyncio as redis

from aiocache import caches, Cache
from aiocache.serializers import StringSerializer, PickleSerializer

@@ -12,9 +14,9 @@
},
'redis_alt': {
'cache': "aiocache.RedisCache",
'endpoint': "127.0.0.1",
"host": "127.0.0.1",
'port': 6379,
'timeout': 1,
"socket_connect_timeout": 1,
'serializer': {
'class': "aiocache.serializers.PickleSerializer"
},
@@ -45,17 +47,18 @@ async def alt_cache():
assert isinstance(cache, Cache.REDIS)
assert isinstance(cache.serializer, PickleSerializer)
assert len(cache.plugins) == 2
assert cache.endpoint == "127.0.0.1"
assert cache.timeout == 1
assert cache.port == 6379
connection_args = cache.client.connection_pool.connection_kwargs
assert connection_args["host"] == "127.0.0.1"
assert connection_args["socket_connect_timeout"] == 1
assert connection_args["port"] == 6379
await cache.close()


async def test_alias():
await default_cache()
await alt_cache()

cache = Cache(Cache.REDIS)
cache = Cache(Cache.REDIS, client=redis.Redis())
await cache.delete("key")
await cache.close()

5 changes: 3 additions & 2 deletions examples/cached_decorator.py
@@ -1,6 +1,7 @@
import asyncio

from collections import namedtuple
import redis.asyncio as redis

from aiocache import cached, Cache
from aiocache.serializers import PickleSerializer
@@ -10,13 +11,13 @@

@cached(
ttl=10, cache=Cache.REDIS, key_builder=lambda *args, **kw: "key",
serializer=PickleSerializer(), port=6379, namespace="main")
serializer=PickleSerializer(), namespace="main", client=redis.Redis())
async def cached_call():
return Result("content", 200)


async def test_cached():
async with Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main") as cache:
async with Cache(Cache.REDIS, namespace="main", client=redis.Redis()) as cache:
await cached_call()
exists = await cache.exists("key")
assert exists is True
11 changes: 6 additions & 5 deletions examples/multicached_decorator.py
@@ -1,5 +1,7 @@
import asyncio

import redis.asyncio as redis

from aiocache import multi_cached, Cache

DICT = {
@@ -9,20 +11,19 @@
'd': "W"
}

cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


@multi_cached("ids", cache=Cache.REDIS, namespace="main")
@multi_cached("ids", cache=Cache.REDIS, namespace="main", client=cache.client)
async def multi_cached_ids(ids=None):
return {id_: DICT[id_] for id_ in ids}


@multi_cached("keys", cache=Cache.REDIS, namespace="main")
@multi_cached("keys", cache=Cache.REDIS, namespace="main", client=cache.client)
async def multi_cached_keys(keys=None):
return {id_: DICT[id_] for id_ in keys}


cache = Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main")


async def test_multi_cached():
await multi_cached_ids(ids=("a", "b"))
await multi_cached_ids(ids=("a", "c"))
5 changes: 3 additions & 2 deletions examples/optimistic_lock.py
@@ -2,12 +2,13 @@
import logging
import random

import redis.asyncio as redis

from aiocache import Cache
from aiocache.lock import OptimisticLock, OptimisticLockError


logger = logging.getLogger(__name__)
cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main')
cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


async def expensive_function():
6 changes: 4 additions & 2 deletions examples/python_object.py
@@ -1,12 +1,14 @@
import asyncio

from collections import namedtuple
import redis.asyncio as redis


from aiocache import Cache
from aiocache.serializers import PickleSerializer


MyObject = namedtuple("MyObject", ["x", "y"])
cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main")
cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main", client=redis.Redis())


async def complex_object():
5 changes: 3 additions & 2 deletions examples/redlock.py
@@ -1,12 +1,13 @@
import asyncio
import logging

import redis.asyncio as redis

from aiocache import Cache
from aiocache.lock import RedLock


logger = logging.getLogger(__name__)
cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main')
cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


async def expensive_function():
4 changes: 3 additions & 1 deletion examples/serializer_class.py
@@ -1,6 +1,8 @@
import asyncio
import zlib

import redis.asyncio as redis

from aiocache import Cache
from aiocache.serializers import BaseSerializer

@@ -25,7 +27,7 @@ def loads(self, value):
return decompressed


cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main")
cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main", client=redis.Redis())


async def serializer():
4 changes: 3 additions & 1 deletion examples/serializer_function.py
@@ -1,6 +1,8 @@
import asyncio
import json

import redis.asyncio as redis

from marshmallow import Schema, fields, post_load
from aiocache import Cache

@@ -28,7 +30,7 @@ def loads(value):
return MyTypeSchema().loads(value)


cache = Cache(Cache.REDIS, namespace="main")
cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis())


async def serializer_function():