Skip to content

Commit 288e61e

Browse files
committed
First commit
1 parent c2fed7f commit 288e61e

File tree

8 files changed

+249
-3
lines changed

8 files changed

+249
-3
lines changed

app.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import asyncio
2+
import hashlib
3+
import secrets
4+
import logging
5+
import os
6+
7+
from pydantic import BaseModel
8+
from transformers import pipeline
9+
from fastapi import FastAPI, HTTPException, status, Depends
10+
from fastapi.security import OAuth2PasswordBearer
11+
from concurrent.futures import ThreadPoolExecutor
12+
from dotenv import load_dotenv
13+
14+
load_dotenv()
15+
16+
app = FastAPI()
17+
18+
# auth with a bearer api key, whoose hash is stored in the environment variable API_KEY_HASH
19+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
20+
API_KEY_HASH = os.getenv("API_KEY_HASH")
21+
assert API_KEY_HASH, "API_KEY_HASH environment variable must be set"
22+
23+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
24+
25+
DEFAULT_LABELS: list[str] = [
26+
"programming",
27+
"politics",
28+
"sports",
29+
"science",
30+
"technology",
31+
"video games",
32+
]
33+
34+
pool = ThreadPoolExecutor(max_workers=1)
35+
36+
logging.basicConfig(level=logging.INFO)
37+
38+
39+
class Classification(BaseModel):
40+
sequence: str = "The text to classify"
41+
labels: list[str] = DEFAULT_LABELS
42+
scores: list[float] = [0.0] * len(DEFAULT_LABELS)
43+
44+
45+
def classify_sync(message: str, labels: list[str]) -> dict:
46+
result = classifier(message, candidate_labels=labels)
47+
return result
48+
49+
50+
# setup auth
51+
def verify_api_key(token: str):
52+
token_hash: str = hashlib.sha256(token.encode()).hexdigest()
53+
return secrets.compare_digest(token_hash, API_KEY_HASH)
54+
55+
56+
async def authenticate_user(token: str = Depends(oauth2_scheme)):
57+
if not verify_api_key(token):
58+
raise HTTPException(
59+
status_code=status.HTTP_401_UNAUTHORIZED,
60+
detail="Invalid API Key",
61+
headers={"WWW-Authenticate": "Bearer"},
62+
)
63+
return token
64+
65+
66+
classification_lock = asyncio.Lock() # Ensure only one classification at a time
67+
68+
69+
@app.get("/v1/classify")
70+
async def classify(
71+
message: str, labels: list[str] = None, token: str = Depends(authenticate_user)
72+
) -> Classification:
73+
"""
74+
Classify the message into one of the labels
75+
:param message: The message to classify
76+
:type message: str
77+
:param labels: The labels to classify the message into
78+
:type labels: list[str]
79+
:return: The classification result
80+
:rtype: Classification
81+
"""
82+
labels = labels or DEFAULT_LABELS
83+
async with classification_lock:
84+
loop = asyncio.get_event_loop()
85+
result = await loop.run_in_executor(None, classify_sync, message, labels)
86+
result = Classification(**result)
87+
return result
88+
89+
90+
@app.get("/v1/health")
91+
async def health() -> dict:
92+
return {"status": "ok"}

message_classifier/__init__.py

Whitespace-only changes.

pdm.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

requiements.txt

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
annotated-types==0.7.0 ; python_version >= "3.11" and python_version < "4.0"
2+
anyio==4.4.0 ; python_version >= "3.11" and python_version < "4.0"
3+
certifi==2024.6.2 ; python_version >= "3.11" and python_version < "4.0"
4+
charset-normalizer==3.3.2 ; python_version >= "3.11" and python_version < "4.0"
5+
click==8.1.7 ; python_version >= "3.11" and python_version < "4.0"
6+
colorama==0.4.6 ; python_version >= "3.11" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows")
7+
dnspython==2.6.1 ; python_version >= "3.11" and python_version < "4.0"
8+
email-validator==2.1.2 ; python_version >= "3.11" and python_version < "4.0"
9+
fastapi-cli==0.0.4 ; python_version >= "3.11" and python_version < "4.0"
10+
fastapi==0.111.0 ; python_version >= "3.11" and python_version < "4.0"
11+
filelock==3.15.3 ; python_version >= "3.11" and python_version < "4.0"
12+
fsspec==2024.6.0 ; python_version >= "3.11" and python_version < "4.0"
13+
h11==0.14.0 ; python_version >= "3.11" and python_version < "4.0"
14+
httpcore==1.0.5 ; python_version >= "3.11" and python_version < "4.0"
15+
httptools==0.6.1 ; python_version >= "3.11" and python_version < "4.0"
16+
httpx==0.27.0 ; python_version >= "3.11" and python_version < "4.0"
17+
huggingface-hub==0.23.4 ; python_version >= "3.11" and python_version < "4.0"
18+
idna==3.7 ; python_version >= "3.11" and python_version < "4.0"
19+
intel-openmp==2021.4.0 ; python_version >= "3.11" and python_version < "4.0" and platform_system == "Windows"
20+
jinja2==3.1.4 ; python_version >= "3.11" and python_version < "4.0"
21+
markdown-it-py==3.0.0 ; python_version >= "3.11" and python_version < "4.0"
22+
markupsafe==2.1.5 ; python_version >= "3.11" and python_version < "4.0"
23+
mdurl==0.1.2 ; python_version >= "3.11" and python_version < "4.0"
24+
mkl==2021.4.0 ; python_version >= "3.11" and python_version < "4.0" and platform_system == "Windows"
25+
mpmath==1.3.0 ; python_version >= "3.11" and python_version < "4.0"
26+
networkx==3.3 ; python_version >= "3.11" and python_version < "4.0"
27+
numpy==1.26.4 ; python_version >= "3.11" and python_version < "4.0"
28+
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
29+
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
30+
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
31+
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
32+
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
33+
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
34+
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
35+
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
36+
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
37+
nvidia-nccl-cu12==2.20.5 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
38+
nvidia-nvjitlink-cu12==12.5.40 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
39+
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.11" and python_version < "4.0"
40+
orjson==3.10.5 ; python_version >= "3.11" and python_version < "4.0"
41+
packaging==24.1 ; python_version >= "3.11" and python_version < "4.0"
42+
pybind11==2.12.0 ; python_version >= "3.11" and python_version < "4.0"
43+
pydantic-core==2.18.4 ; python_version >= "3.11" and python_version < "4.0"
44+
pydantic==2.7.4 ; python_version >= "3.11" and python_version < "4.0"
45+
pygments==2.18.0 ; python_version >= "3.11" and python_version < "4.0"
46+
python-dotenv==1.0.1 ; python_version >= "3.11" and python_version < "4.0"
47+
python-multipart==0.0.9 ; python_version >= "3.11" and python_version < "4.0"
48+
pyyaml==6.0.1 ; python_version >= "3.11" and python_version < "4.0"
49+
regex==2024.5.15 ; python_version >= "3.11" and python_version < "4.0"
50+
requests==2.32.3 ; python_version >= "3.11" and python_version < "4.0"
51+
rich==13.7.1 ; python_version >= "3.11" and python_version < "4.0"
52+
safetensors==0.4.3 ; python_version >= "3.11" and python_version < "4.0"
53+
shellingham==1.5.4 ; python_version >= "3.11" and python_version < "4.0"
54+
sniffio==1.3.1 ; python_version >= "3.11" and python_version < "4.0"
55+
starlette==0.37.2 ; python_version >= "3.11" and python_version < "4.0"
56+
sympy==1.12.1 ; python_version >= "3.11" and python_version < "4.0"
57+
tbb==2021.12.0 ; python_version >= "3.11" and python_version < "4.0" and platform_system == "Windows"
58+
tokenizers==0.19.1 ; python_version >= "3.11" and python_version < "4.0"
59+
torch==2.3.1 ; python_version >= "3.11" and python_version < "4.0"
60+
tqdm==4.66.4 ; python_version >= "3.11" and python_version < "4.0"
61+
transformers==4.41.2 ; python_version >= "3.11" and python_version < "4.0"
62+
triton==2.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12" and python_version >= "3.11"
63+
typer==0.12.3 ; python_version >= "3.11" and python_version < "4.0"
64+
typing-extensions==4.12.2 ; python_version >= "3.11" and python_version < "4.0"
65+
ujson==5.10.0 ; python_version >= "3.11" and python_version < "4.0"
66+
urllib3==2.2.2 ; python_version >= "3.11" and python_version < "4.0"
67+
uvicorn[standard]==0.30.1 ; python_version >= "3.11" and python_version < "4.0"
68+
uvloop==0.19.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.11" and python_version < "4.0"
69+
watchfiles==0.22.0 ; python_version >= "3.11" and python_version < "4.0"
70+
websockets==12.0 ; python_version >= "3.11" and python_version < "4.0"

requirements.txt

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# This file is @generated by PDM.
2+
# Please do not edit it manually.
3+
4+
annotated-types==0.7.0
5+
anyio==4.4.0
6+
certifi==2024.6.2
7+
charset-normalizer==3.3.2
8+
click==8.1.7
9+
colorama==0.4.6; sys_platform == "win32" or platform_system == "Windows"
10+
dnspython==2.6.1
11+
email-validator==2.2.0
12+
fastapi==0.111.0
13+
fastapi-cli==0.0.4
14+
filelock==3.15.3
15+
fsspec==2024.6.0
16+
h11==0.14.0
17+
httpcore==1.0.5
18+
httptools==0.6.1
19+
httpx==0.27.0
20+
huggingface-hub==0.23.4
21+
idna==3.7
22+
intel-openmp==2021.4.0; platform_system == "Windows"
23+
jinja2==3.1.4
24+
markdown-it-py==3.0.0
25+
markupsafe==2.1.5
26+
mdurl==0.1.2
27+
mkl==2021.4.0; platform_system == "Windows"
28+
mpmath==1.3.0
29+
networkx==3.3
30+
numpy==1.26.4
31+
nvidia-cublas-cu12==12.1.3.1; platform_system == "Linux" and platform_machine == "x86_64"
32+
nvidia-cuda-cupti-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
33+
nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
34+
nvidia-cuda-runtime-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
35+
nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64"
36+
nvidia-cufft-cu12==11.0.2.54; platform_system == "Linux" and platform_machine == "x86_64"
37+
nvidia-curand-cu12==10.3.2.106; platform_system == "Linux" and platform_machine == "x86_64"
38+
nvidia-cusolver-cu12==11.4.5.107; platform_system == "Linux" and platform_machine == "x86_64"
39+
nvidia-cusparse-cu12==12.1.0.106; platform_system == "Linux" and platform_machine == "x86_64"
40+
nvidia-nccl-cu12==2.20.5; platform_system == "Linux" and platform_machine == "x86_64"
41+
nvidia-nvjitlink-cu12==12.5.40; platform_system == "Linux" and platform_machine == "x86_64"
42+
nvidia-nvtx-cu12==12.1.105; platform_system == "Linux" and platform_machine == "x86_64"
43+
orjson==3.10.5
44+
packaging==24.1
45+
pydantic==2.7.4
46+
pydantic-core==2.18.4
47+
pygments==2.18.0
48+
python-dotenv==1.0.1
49+
python-multipart==0.0.9
50+
pyyaml==6.0.1
51+
regex==2024.5.15
52+
requests==2.32.3
53+
rich==13.7.1
54+
safetensors==0.4.3
55+
shellingham==1.5.4
56+
sniffio==1.3.1
57+
starlette==0.37.2
58+
sympy==1.12.1
59+
tbb==2021.12.0; platform_system == "Windows"
60+
tokenizers==0.19.1
61+
torch==2.3.1
62+
tqdm==4.66.4
63+
transformers==4.41.2
64+
triton==2.3.1; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"
65+
typer==0.12.3
66+
typing-extensions==4.12.2
67+
ujson==5.10.0
68+
urllib3==2.2.2
69+
uvicorn==0.30.1
70+
uvloop==0.19.0; (sys_platform != "cygwin" and sys_platform != "win32") and platform_python_implementation != "PyPy"
71+
watchfiles==0.22.0
72+
websockets==12.0

src/__init__.py

Whitespace-only changes.

test_main.http

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Test your FastAPI endpoints
2+
3+
GET http://127.0.0.1:8000/classify
4+
Accept: application/json
5+
6+
7+
###
8+
9+
GET http://127.0.0.1:8000/hello/User
10+
Accept: application/json
11+
12+
###

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)