-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
122 lines (94 loc) · 3.45 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import asyncio
import hashlib
import logging
import os
import secrets

from cachier import cachier
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Query, status
from fastapi.requests import Request
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from transformers import pipeline
# ------------------ SETUP ------------------
# Pull variables from a local .env file (if any) into the environment before
# anything below reads it via os.getenv.
load_dotenv()

# FastAPI application with the Swagger UI, ReDoc and OpenAPI schema endpoints
# all switched off.
app = FastAPI(openapi_url=None, docs_url=None, redoc_url=None)


def _rate_limit_key() -> str:
    """Return a constant key so every caller shares one rate-limit bucket."""
    return "global"


# Rate limiter keyed on a constant, i.e. limits are global rather than per client.
limiter = Limiter(
    key_func=_rate_limit_key,
    default_limits=["1000 per day"],
)
app.state.limiter = limiter  # noqa
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)  # noqa

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ------------------ AUTHENTICATION ------------------
# Auth with a bearer API key. Only the SHA-256 hex digest of the key is
# configured, taken from the API_KEY_HASH environment variable or, if that is
# unset/blank, from the secret file at _API_KEY_HASH_SECRET_PATH.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
_API_KEY_HASH_SECRET_PATH = "/run/secrets/api_key_hash"

# Strip so a stray trailing newline/space in the env var doesn't break auth.
API_KEY_HASH = os.getenv("API_KEY_HASH", "").strip()
if API_KEY_HASH:
    logger.info("API key hash loaded from environment variable")
elif os.path.exists(_API_KEY_HASH_SECRET_PATH):
    with open(_API_KEY_HASH_SECRET_PATH, "r", encoding="utf-8") as secret_file:
        API_KEY_HASH = secret_file.read().strip()
    logger.info("API key hash loaded from secret")

# hashlib's hexdigest() is lowercase, so normalise the configured digest to
# match; an uppercase hash would otherwise reject every key.
API_KEY_HASH = API_KEY_HASH.lower()

# Explicit raise rather than `assert`: asserts are removed under `python -O`,
# which would let the app start without a key and fail on every request.
if not API_KEY_HASH:
    raise RuntimeError("API_KEY_HASH must be set")
# Check a presented API key against the configured hash.
def verify_api_key(token: str):
    """Return True iff the SHA-256 hex digest of *token* equals API_KEY_HASH.

    The comparison uses ``secrets.compare_digest`` so its timing does not
    depend on where the two digests first differ.
    """
    presented_digest: str = hashlib.sha256(token.encode()).hexdigest()
    matches = secrets.compare_digest(presented_digest, API_KEY_HASH)
    return matches
# FastAPI dependency guarding protected routes.
async def authenticate_user(token: str = Depends(oauth2_scheme)):
    """Accept the request only if its bearer token is the valid API key.

    Returns:
        The token unchanged when ``verify_api_key`` accepts it.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header otherwise.
    """
    if verify_api_key(token):
        return token
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API Key",
        headers={"WWW-Authenticate": "Bearer"},
    )
# ------------------ CLASSIFICATION ------------------
# Zero-shot classification pipeline on facebook/bart-large-mnli. It is built
# once, when this module is imported, and reused by every classification.
classifier = pipeline(
    task="zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate labels used when a request supplies none of its own.
DEFAULT_LABELS: list[str] = [
    "programming", "politics", "sports",
    "science", "technology", "video games",
]
# Response schema for the classify route.
class Classification(BaseModel):
    """Classified text together with its candidate labels and their scores.

    The defaults are placeholders; the route builds instances directly from
    the pipeline's output dict.
    """

    sequence: str = "The text to classify"
    labels: list[str] = DEFAULT_LABELS
    # Placeholder: one zero score per default label.
    scores: list[float] = [0.0 for _ in DEFAULT_LABELS]
# Blocking classification call, memoised on disk.
@cachier(cache_dir="./cache")
def classify_sync(message: str, labels: list[str]) -> dict:
    """Run the zero-shot classifier on *message* over the candidate *labels*.

    Blocking (the route runs it in an executor). cachier persists results
    under ./cache, so a repeated (message, labels) call skips inference.
    NOTE(review): no ``stale_after`` is configured, so the cache never expires
    and grows with every distinct input — confirm this is acceptable.
    """
    return classifier(message, candidate_labels=labels)
# ------------------ ROUTES ------------------
# Lock to ensure only one classification runs at a time against the shared,
# module-level classifier.
classification_lock = asyncio.Lock()


# Route to classify message
@app.get("/v1/classify")
@limiter.limit("6/second")
async def classify(
    request: Request,
    message: str,
    # Query() is required here: without it FastAPI treats a list[str]
    # parameter as a JSON request body, so GET clients could not pass labels
    # as ?labels=a&labels=b.
    labels: list[str] = Query(default=None),
    token: str = Depends(authenticate_user),
) -> Classification:
    """Classify *message* against candidate *labels* with the zero-shot model.

    Args:
        request: Incoming request; slowapi's ``limiter.limit`` decorator needs
            it in the signature to apply the rate limit.
        message: Text to classify (query parameter).
        labels: Candidate labels as repeated query parameters. Falls back to
            ``DEFAULT_LABELS`` when omitted or empty.
        token: API key validated by ``authenticate_user``; unused here, but
            declaring the dependency makes the route reject bad keys with 401.

    Returns:
        The classifier output as a ``Classification``.
    """
    candidate_labels = labels or DEFAULT_LABELS
    async with classification_lock:  # Ensure only one classification at a time
        # Inference blocks; run it in the default executor so the event loop
        # keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, classify_sync, message, candidate_labels
        )
    # Model construction needs no lock; release it before validating.
    return Classification(**result)
# Health check route
@app.get("/v1/health")
async def health() -> dict:
    """Liveness probe: always answers {"status": "ok"}; requires no API key."""
    payload = {"status": "ok"}
    return payload