Skip to content

Commit 2b1224e

Browse files
committed
🎨 Format code in sections
1 parent 743a918 commit 2b1224e

File tree

1 file changed

+44
-32
lines changed

1 file changed

+44
-32
lines changed

app.py

+44-32
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import secrets
44
import logging
55
import os
6-
76
from pydantic import BaseModel
87
from transformers import pipeline
98
from fastapi import FastAPI, HTTPException, status, Depends
@@ -12,21 +11,56 @@
1211
from dotenv import load_dotenv
1312
from cachier import cachier
1413

14+
# ------------------ SETUP ------------------
15+
16+
# Load environment variables
1517
load_dotenv()
1618

19+
# Create FastAPI instance
1720
app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
1821

19-
# auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH
22+
# Setup logging
23+
logging.basicConfig(level=logging.INFO)
24+
logger = logging.getLogger(__name__)
25+
26+
# ------------------ AUTHENTICATION ------------------
27+
28+
# Auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH
2029
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
2130
API_KEY_HASH = os.getenv("API_KEY_HASH")
2231
if not API_KEY_HASH and os.path.exists("/run/secrets/api_key_hash"):
2332
with open("/run/secrets/api_key_hash", "r") as f:
2433
API_KEY_HASH = f.read().strip()
34+
logger.info("API key hash loaded from secret")
35+
else:
36+
logger.info("API key hash loaded from environment variable")
2537

2638
assert API_KEY_HASH, "API_KEY_HASH must be set"
2739

40+
41+
# Function to verify API key
42+
def verify_api_key(token: str):
43+
token_hash: str = hashlib.sha256(token.encode()).hexdigest()
44+
return secrets.compare_digest(token_hash, API_KEY_HASH)
45+
46+
47+
# Dependency to authenticate user
48+
async def authenticate_user(token: str = Depends(oauth2_scheme)):
49+
if not verify_api_key(token):
50+
raise HTTPException(
51+
status_code=status.HTTP_401_UNAUTHORIZED,
52+
detail="Invalid API Key",
53+
headers={"WWW-Authenticate": "Bearer"},
54+
)
55+
return token
56+
57+
58+
# ------------------ CLASSIFICATION ------------------
59+
60+
# Setup classifier
2861
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
2962

63+
# Default labels
3064
DEFAULT_LABELS: list[str] = [
3165
"programming",
3266
"politics",
@@ -36,63 +70,41 @@
3670
"video games",
3771
]
3872

39-
pool = ThreadPoolExecutor(max_workers=1)
40-
41-
logging.basicConfig(level=logging.INFO)
42-
4373

74+
# Classification model
4475
class Classification(BaseModel):
4576
sequence: str = "The text to classify"
4677
labels: list[str] = DEFAULT_LABELS
4778
scores: list[float] = [0.0] * len(DEFAULT_LABELS)
4879

4980

81+
# Function to classify message
5082
@cachier(cache_dir="./cache")
5183
def classify_sync(message: str, labels: list[str]) -> dict:
5284
result = classifier(message, candidate_labels=labels)
5385
return result
5486

5587

56-
# setup auth
57-
def verify_api_key(token: str):
58-
token_hash: str = hashlib.sha256(token.encode()).hexdigest()
59-
return secrets.compare_digest(token_hash, API_KEY_HASH)
60-
61-
62-
async def authenticate_user(token: str = Depends(oauth2_scheme)):
63-
if not verify_api_key(token):
64-
raise HTTPException(
65-
status_code=status.HTTP_401_UNAUTHORIZED,
66-
detail="Invalid API Key",
67-
headers={"WWW-Authenticate": "Bearer"},
68-
)
69-
return token
70-
88+
# ------------------ ROUTES ------------------
7189

72-
classification_lock = asyncio.Lock() # Ensure only one classification at a time
90+
# Lock to ensure only one classification at a time
91+
classification_lock = asyncio.Lock()
7392

7493

94+
# Route to classify message
7595
@app.get("/v1/classify")
7696
async def classify(
7797
message: str, labels: list[str] = None, token: str = Depends(authenticate_user)
7898
) -> Classification:
79-
"""
80-
Classify the message into one of the labels
81-
:param message: The message to classify
82-
:type message: str
83-
:param labels: The labels to classify the message into
84-
:type labels: list[str]
85-
:return: The classification result
86-
:rtype: Classification
87-
"""
8899
labels = labels or DEFAULT_LABELS
89-
async with classification_lock:
100+
async with classification_lock: # Ensure only one classification at a time
90101
loop = asyncio.get_event_loop()
91102
result = await loop.run_in_executor(None, classify_sync, message, labels)
92103
result = Classification(**result)
93104
return result
94105

95106

107+
# Health check route
96108
@app.get("/v1/health")
97109
async def health() -> dict:
98110
return {"status": "ok"}

0 commit comments

Comments
 (0)