|
3 | 3 | import secrets
|
4 | 4 | import logging
|
5 | 5 | import os
|
6 |
| - |
7 | 6 | from pydantic import BaseModel
|
8 | 7 | from transformers import pipeline
|
9 | 8 | from fastapi import FastAPI, HTTPException, status, Depends
|
|
12 | 11 | from dotenv import load_dotenv
|
13 | 12 | from cachier import cachier
|
14 | 13 |
|
| 14 | +# ------------------ SETUP ------------------ |
| 15 | + |
| 16 | +# Load environment variables |
15 | 17 | load_dotenv()
|
16 | 18 |
|
| 19 | +# Create FastAPI instance |
17 | 20 | app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
|
18 | 21 |
|
19 |
| -# auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH |
| 22 | +# Setup logging |
| 23 | +logging.basicConfig(level=logging.INFO) |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | +# ------------------ AUTHENTICATION ------------------ |
| 27 | + |
| 28 | +# Auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH |
20 | 29 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
21 | 30 | API_KEY_HASH = os.getenv("API_KEY_HASH")
|
22 | 31 | if not API_KEY_HASH and os.path.exists("/run/secrets/api_key_hash"):
|
23 | 32 | with open("/run/secrets/api_key_hash", "r") as f:
|
24 | 33 | API_KEY_HASH = f.read().strip()
|
| 34 | + logger.info("API key hash loaded from secret") |
| 35 | +else: |
| 36 | + logger.info("API key hash loaded from environment variable") |
25 | 37 |
|
26 | 38 | assert API_KEY_HASH, "API_KEY_HASH must be set"
|
27 | 39 |
|
| 40 | + |
| 41 | +# Function to verify API key |
| 42 | +def verify_api_key(token: str): |
| 43 | + token_hash: str = hashlib.sha256(token.encode()).hexdigest() |
| 44 | + return secrets.compare_digest(token_hash, API_KEY_HASH) |
| 45 | + |
| 46 | + |
| 47 | +# Dependency to authenticate user |
| 48 | +async def authenticate_user(token: str = Depends(oauth2_scheme)): |
| 49 | + if not verify_api_key(token): |
| 50 | + raise HTTPException( |
| 51 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 52 | + detail="Invalid API Key", |
| 53 | + headers={"WWW-Authenticate": "Bearer"}, |
| 54 | + ) |
| 55 | + return token |
| 56 | + |
| 57 | + |
| 58 | +# ------------------ CLASSIFICATION ------------------ |
| 59 | + |
| 60 | +# Setup classifier |
28 | 61 | classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
29 | 62 |
|
| 63 | +# Default labels |
30 | 64 | DEFAULT_LABELS: list[str] = [
|
31 | 65 | "programming",
|
32 | 66 | "politics",
|
|
36 | 70 | "video games",
|
37 | 71 | ]
|
38 | 72 |
|
39 |
| -pool = ThreadPoolExecutor(max_workers=1) |
40 |
| - |
41 |
| -logging.basicConfig(level=logging.INFO) |
42 |
| - |
43 | 73 |
|
| 74 | +# Classification model |
44 | 75 | class Classification(BaseModel):
|
45 | 76 | sequence: str = "The text to classify"
|
46 | 77 | labels: list[str] = DEFAULT_LABELS
|
47 | 78 | scores: list[float] = [0.0] * len(DEFAULT_LABELS)
|
48 | 79 |
|
49 | 80 |
|
| 81 | +# Function to classify message |
50 | 82 | @cachier(cache_dir="./cache")
|
51 | 83 | def classify_sync(message: str, labels: list[str]) -> dict:
|
52 | 84 | result = classifier(message, candidate_labels=labels)
|
53 | 85 | return result
|
54 | 86 |
|
55 | 87 |
|
56 |
| -# setup auth |
57 |
| -def verify_api_key(token: str): |
58 |
| - token_hash: str = hashlib.sha256(token.encode()).hexdigest() |
59 |
| - return secrets.compare_digest(token_hash, API_KEY_HASH) |
60 |
| - |
61 |
| - |
62 |
| -async def authenticate_user(token: str = Depends(oauth2_scheme)): |
63 |
| - if not verify_api_key(token): |
64 |
| - raise HTTPException( |
65 |
| - status_code=status.HTTP_401_UNAUTHORIZED, |
66 |
| - detail="Invalid API Key", |
67 |
| - headers={"WWW-Authenticate": "Bearer"}, |
68 |
| - ) |
69 |
| - return token |
70 |
| - |
| 88 | +# ------------------ ROUTES ------------------ |
71 | 89 |
|
72 |
| -classification_lock = asyncio.Lock() # Ensure only one classification at a time |
| 90 | +# Lock to ensure only one classification at a time |
| 91 | +classification_lock = asyncio.Lock() |
73 | 92 |
|
74 | 93 |
|
| 94 | +# Route to classify message |
75 | 95 | @app.get("/v1/classify")
|
76 | 96 | async def classify(
|
77 | 97 | message: str, labels: list[str] = None, token: str = Depends(authenticate_user)
|
78 | 98 | ) -> Classification:
|
79 |
| - """ |
80 |
| - Classify the message into one of the labels |
81 |
| - :param message: The message to classify |
82 |
| - :type message: str |
83 |
| - :param labels: The labels to classify the message into |
84 |
| - :type labels: list[str] |
85 |
| - :return: The classification result |
86 |
| - :rtype: Classification |
87 |
| - """ |
88 | 99 | labels = labels or DEFAULT_LABELS
|
89 |
| - async with classification_lock: |
| 100 | + async with classification_lock: # Ensure only one classification at a time |
90 | 101 | loop = asyncio.get_event_loop()
|
91 | 102 | result = await loop.run_in_executor(None, classify_sync, message, labels)
|
92 | 103 | result = Classification(**result)
|
93 | 104 | return result
|
94 | 105 |
|
95 | 106 |
|
| 107 | +# Health check route |
96 | 108 | @app.get("/v1/health")
|
97 | 109 | async def health() -> dict:
|
98 | 110 | return {"status": "ok"}
|
0 commit comments