-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b3debb5
commit e7b82a4
Showing
288 changed files
with
9,793 additions
and
11,251 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .http import auth_router | ||
from .websocket import identify_router |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
-*- coding: utf-8 -*- | ||
@Organization : SupaVision | ||
@Author : 18317 | ||
@Date Created : 05/01/2024 | ||
@Description : | ||
""" | ||
|
||
import logging | ||
from typing import Annotated | ||
|
||
from fastapi import Depends, HTTPException | ||
from fastapi.security import OAuth2PasswordBearer | ||
from gotrue.errors import AuthApiError | ||
from supabase_py_async import AsyncClient, create_client | ||
from supabase_py_async.lib.client_options import ClientOptions | ||
|
||
from ..core.config import settings | ||
from ..schemas.auth import UserIn | ||
|
||
super_client: AsyncClient | None = None | ||
|
||
|
||
async def init_super_client() -> None: | ||
"""for validation access_token init at life span event""" | ||
global super_client | ||
super_client = await create_client( | ||
settings.SUPABASE_URL, | ||
settings.SUPABASE_KEY, | ||
options=ClientOptions(postgrest_client_timeout=10, storage_client_timeout=10), | ||
) | ||
# await super_client.auth.sign_in_with_password( | ||
# {"email": settings.SUPERUSER_EMAIL, "password": settings.SUPERUSER_PASSWORD} | ||
# ) | ||
|
||
|
||
# auto get access_token from header | ||
reusable_oauth2 = OAuth2PasswordBearer( | ||
tokenUrl="please login by supabase-js to get token" | ||
) | ||
AccessTokenDep = Annotated[str, Depends(reusable_oauth2)] | ||
|
||
|
||
async def get_current_user(access_token: AccessTokenDep) -> UserIn: | ||
"""get current user from access_token and validate same time""" | ||
if not super_client: | ||
raise HTTPException(status_code=500, detail="Super client not initialized") | ||
|
||
user_rsp = await super_client.auth.get_user(jwt=access_token) | ||
if not user_rsp: | ||
logging.error("User not found") | ||
raise HTTPException(status_code=404, detail="User not found") | ||
return UserIn(**user_rsp.user.model_dump(), access_token=access_token) | ||
|
||
|
||
CurrentUser = Annotated[UserIn, Depends(get_current_user)] | ||
|
||
|
||
async def get_db(user: CurrentUser) -> AsyncClient: | ||
client: AsyncClient | None = None | ||
try: | ||
client = await create_client( | ||
settings.SUPABASE_URL, | ||
settings.SUPABASE_KEY, | ||
access_token=user.access_token, | ||
options=ClientOptions( | ||
postgrest_client_timeout=10, storage_client_timeout=10 | ||
), | ||
) | ||
# checks all done in supabase-py ! | ||
# await client.auth.set_session(token.access_token, token.refresh_token) | ||
# session = await client.auth.get_session() | ||
yield client | ||
|
||
except AuthApiError as e: | ||
logging.error(e) | ||
raise HTTPException( | ||
status_code=401, detail="Invalid authentication credentials" | ||
) | ||
finally: | ||
if client: | ||
await client.auth.sign_out() | ||
|
||
|
||
SessionDep = Annotated[AsyncClient, Depends(get_db)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from fastapi import APIRouter, Body | ||
|
||
from ..common import task_queue | ||
from ..schemas import Face2Search, Face2SearchSchema | ||
from ..services.inference.common import TaskType | ||
|
||
auth_router = APIRouter(prefix="/auth", tags=["auth"]) | ||
|
||
# TODO: add face passport register | ||
# TODO: how to solve distribute results? | ||
# TODO: register face with id and name use sessionDepend | ||
@auth_router.post("/face-register/{id}/{name}") | ||
async def face_register(id: str, name: str, face: Face2SearchSchema = Body(...)) -> str: | ||
""" | ||
register face with id and name | ||
""" | ||
# resp = res | ||
to_register = Face2Search.from_schema(face).to_face() | ||
to_register.sign_up_id = id[:10] | ||
to_register.sign_up_name = name[:10] | ||
|
||
await task_queue.put_async((TaskType.REGISTER, to_register)) | ||
|
||
return "face_register successfully!" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import asyncio | ||
from queue import Empty | ||
|
||
from fastapi import APIRouter | ||
from gotrue import Session | ||
from starlette.websockets import WebSocketDisconnect | ||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK | ||
|
||
from ..common import result_queue, task_queue | ||
from ..core import WebSocketConnection, websocket_endpoint | ||
from ..core.config import logger | ||
from ..schemas import Face2Search, Face2SearchSchema, IdentifyResult, SystemStats | ||
from ..services.db.base_model import MatchedResult | ||
from ..services.inference.common import TaskType | ||
|
||
|
||
identify_router = APIRouter(prefix="/identify", tags=["identify"]) | ||
|
||
|
||
@identify_router.websocket("/ws/") | ||
@websocket_endpoint() | ||
async def identify_ws(connection: WebSocketConnection): | ||
while True: | ||
# test identifyResult | ||
try: | ||
rec_data = await connection.receive_data(Face2SearchSchema) | ||
logger.debug("rec_data:", rec_data) | ||
search_data = Face2Search.from_schema(rec_data) | ||
logger.debug(f"get the search data:{search_data}") | ||
|
||
await task_queue.put_async((TaskType.IDENTIFY, search_data.to_face())) | ||
|
||
try: | ||
res: MatchedResult = await result_queue.get_async() | ||
result = IdentifyResult.from_matched_result(res) | ||
await connection.send_data(result) | ||
except Empty: | ||
logger.warn("empty in result queue") | ||
|
||
# time_now = datetime.datetime.now() | ||
# result = IdentifyResult( | ||
# id=str(uuid.uuid4()), | ||
# name=session.user.user_metadata.get("name"), | ||
# time=time_now.strftime("%Y-%m-%d %H:%M:%S"), | ||
# uid=search_data.uid, | ||
# score=0.99 | ||
# ) | ||
|
||
# await asyncio.sleep(1) # 示例延时 | ||
except ( | ||
ConnectionClosedOK, | ||
ConnectionClosedError, | ||
RuntimeError, | ||
WebSocketDisconnect, | ||
) as e: | ||
logger.info(f"WebSocket error occurred: {e.__class__.__name__} - {e}") | ||
logger.info(f"Client left the chat") | ||
break | ||
|
||
|
||
@identify_router.websocket("/test/ws/{client_id}") | ||
@websocket_endpoint() | ||
async def test_connect(connection: WebSocketConnection, session: Session): | ||
"""cloud_system_monitor websocket""" | ||
while True: | ||
try: | ||
data = await connection.receive_data() | ||
logger.debug(f"test websocket receive data:{data}") | ||
await connection.send_data(data) | ||
logger.debug(f"test websocket send data:{data}") | ||
except ( | ||
ConnectionClosedOK, | ||
ConnectionClosedError, | ||
RuntimeError, | ||
WebSocketDisconnect, | ||
) as e: | ||
logger.info(f"occurred error {e} Client {session.user.id} left the chat") | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .types import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import asyncio | ||
import logging | ||
import multiprocessing | ||
from collections.abc import Callable | ||
from multiprocessing.queues import Queue | ||
from queue import Empty, Full | ||
|
||
|
||
class AsyncProcessQueue(Queue): | ||
def __init__(self, maxsize=1000): | ||
ctx = multiprocessing.get_context() | ||
super().__init__(maxsize, ctx=ctx) | ||
|
||
async def put_async(self, item): | ||
return await self._continued_try(self.put_nowait, item) | ||
|
||
async def get_async(self): | ||
return await self._continued_try(self.get_nowait) | ||
|
||
async def _continued_try(self, operation: Callable, *args): | ||
while True: | ||
try: | ||
return operation(*args) | ||
except Full: | ||
logging.debug("Queue is full") | ||
await asyncio.sleep(0.01) | ||
except Empty: | ||
logging.debug("Queue is empty") | ||
await asyncio.sleep(0.01) | ||
|
||
|
||
task_queue = AsyncProcessQueue() # Queue[tuple[TaskType, Face] | ||
result_queue = AsyncProcessQueue() # Queue[MatchedResult] | ||
registered_queue = AsyncProcessQueue() # Queue[str] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .events import lifespan | ||
from .web_socket_manager import ( | ||
WebSocketConnection, | ||
web_socket_manager, | ||
websocket_endpoint, | ||
) |
Oops, something went wrong.