Skip to content

Commit

Permalink
update redis calls to async/await (#242) (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
russbiggs authored Sep 7, 2023
1 parent f3e9a24 commit 1a73ed1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
11 changes: 7 additions & 4 deletions openaq_api/openaq_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,15 @@ def render(self, content: Any) -> bytes:

redis_client = None # initialize for generalize_schema.py


if settings.RATE_LIMITING:
logger.debug("Connecting to redis")
import redis
from redis.asyncio.cluster import RedisCluster

try:
redis_client = redis.RedisCluster(
redis_client = RedisCluster(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
decode_responses=True,
skip_full_coverage_check=True,
socket_timeout=5,
)
app.state.redis_client = redis_client
Expand Down Expand Up @@ -224,6 +222,11 @@ async def shutdown_event():
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")
if hasattr(app.state, "redis_client") and settings.RATE_LIMITING:
logger.debug("Closing redis connection")
await app.state.redis_client.close()
delattr(app.state, "redis_client")
logger.debug("redis connection closed")


@app.get("/ping", include_in_schema=False)
Expand Down
28 changes: 15 additions & 13 deletions openaq_api/openaq_api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from fastapi import Response, status
from fastapi.responses import JSONResponse
from redis import Redis
from redis.asyncio.cluster import RedisCluster
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp
Expand Down Expand Up @@ -103,32 +103,33 @@ class RateLimiterMiddleWare(BaseHTTPMiddleware):
def __init__(
self,
app: ASGIApp,
redis_client: Redis,
redis_client: RedisCluster,
rate_amount: int, # number of requests allowed without api key
rate_amount_key: int, # number of requests allowed with api key
rate_time: timedelta, # timedelta of rate limit expiration
) -> None:
"""Init Middleware."""
super().__init__(app)
self.redis_client = redis_client
self.rate_amount = rate_amount
self.rate_amount_key = rate_amount_key
self.rate_amount = rate_amount # 100
self.rate_amount_key = rate_amount_key # 400
self.rate_time = rate_time
self.counter = 0

def request_is_limited(self, key: str, limit: int):
if self.redis_client.setnx(key, limit):
self.redis_client.expire(key, int(self.rate_time.total_seconds()))
count = self.redis_client.get(key)
async def request_is_limited(self, key: str, limit: int):
if await self.redis_client.set(key, limit, nx=True):
await self.redis_client.expire(key, int(self.rate_time.total_seconds()))
count = await self.redis_client.get(key)
if count and int(count) > 0:
self.counter = self.redis_client.decrby(key, 1)
self.counter = await self.redis_client.decrby(key, 1)
return False
if int(count) < 0:
self.redis_client.delete(key)
logger.error(f"rate limiter hit a value below zero: {count} for key: {key}")
await self.redis_client.delete(key)
return True

def check_valid_key(self, key: str):
if self.redis_client.sismember("keys", key):
async def check_valid_key(self, key: str):
if await self.redis_client.sismember("keys", key):
return True
return False

Expand Down Expand Up @@ -166,7 +167,8 @@ async def dispatch(
)
key = auth
limit = self.rate_amount_key
if self.limited_path(route) and self.request_is_limited(key, limit):
limited = await self.request_is_limited(key, limit)
if self.limited_path(route) and limited:
logging.info(
TooManyRequestsLog(
request=request,
Expand Down
2 changes: 1 addition & 1 deletion openaq_api/openaq_api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()):
token = await db.get_user_token(row[0])
if request.app.state.redis_client:
redis_client = request.app.state.redis_client
redis_client.sadd("keys", token)
await redis_client.sadd("keys", token)
send_api_key_email(token, row[3], row[4])
return templates.TemplateResponse(
"verify/index.html", {"request": request, "error": False, "verify": True}
Expand Down

0 comments on commit 1a73ed1

Please sign in to comment.