Skip to content

Commit

Permalink
refactor: use State to keep multithreading within controller
Browse files Browse the repository at this point in the history
  • Loading branch information
winstxnhdw committed Sep 6, 2024
1 parent 5db5ec3 commit 4f2b300
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 59 deletions.
32 changes: 32 additions & 0 deletions capgen/transcriber/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import BinaryIO, Protocol


class TranscriberProtocol(Protocol):
"""
Summary
-------
a protocol for all transcriber(s)
Methods
-------
transcribe(file: str | BinaryIO, caption_format: str) -> str | None:
converts transcription segments into a specific caption format
"""

__slots__ = ('model',)

def transcribe(self, file: str | BinaryIO, caption_format: str) -> str | None:
"""
Summary
-------
transcribes a compatible audio/video into a chosen caption format
Parameters
----------
file (str | BinaryIO) : the file to transcribe
caption_format (str) : the chosen caption format
Returns
-------
transcription (str | None) : the transcribed text in the chosen caption format
"""
3 changes: 2 additions & 1 deletion capgen/transcriber/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from faster_whisper import WhisperModel

from capgen.transcriber.converter import Converter
from capgen.transcriber.protocol import TranscriberProtocol


class WhisperParameters(TypedDict):
Expand All @@ -19,7 +20,7 @@ class WhisperParameters(TypedDict):
num_workers: int


class Transcriber:
class Transcriber(TranscriberProtocol):
"""
Summary
-------
Expand Down
13 changes: 10 additions & 3 deletions server/api/v1/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from asyncio import wrap_future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Annotated, Literal

Expand All @@ -8,8 +10,8 @@
from litestar.params import Body
from litestar.status_codes import HTTP_200_OK

from server.features import Transcriber
from server.schemas.v1 import Transcribed
from server.state import AppState


class TranscriberController(Controller):
Expand All @@ -20,10 +22,12 @@ class TranscriberController(Controller):
"""

path = '/transcribe'
thread_pool = ThreadPoolExecutor()

@post(status_code=HTTP_200_OK)
async def transcribe(
self,
state: AppState,
data: Annotated[UploadFile, Body(media_type=RequestEncodingType.MULTI_PART)],
caption_format: Literal['srt', 'vtt'] = 'srt',
) -> Transcribed:
Expand All @@ -32,7 +36,10 @@ async def transcribe(
-------
the POST variant of the `/transcribe` route
"""
if not (result := await Transcriber.transcribe(BytesIO(await data.read()), caption_format)):
audio = BytesIO(await data.read())
transcription = await wrap_future(self.thread_pool.submit(state.transcriber.transcribe, audio, caption_format))

if not transcription:
raise ClientException(detail=f'Invalid format: {caption_format}!')

return Transcribed(result=result)
return Transcribed(result=transcription)
1 change: 0 additions & 1 deletion server/features/__init__.py

This file was deleted.

50 changes: 0 additions & 50 deletions server/features/transcriber.py

This file was deleted.

15 changes: 11 additions & 4 deletions server/lifespans/load_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator

from server.features import Transcriber
from litestar import Litestar

from capgen.transcriber import Transcriber
from server.config import Config


@asynccontextmanager
async def load_model(_) -> AsyncIterator[None]:
async def load_model(app: Litestar) -> AsyncIterator[None]:
"""
Summary
-------
download and load the model
"""
Transcriber.load()
yield
app.state.transcriber = Transcriber('cpu', number_of_workers=Config.worker_count)

try:
yield
finally:
pass
13 changes: 13 additions & 0 deletions server/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from litestar.datastructures import State

from capgen.transcriber.protocol import TranscriberProtocol


class AppState(State):
    """
    Summary
    -------
    typed view of the Litestar application state that is injected into the route handlers

    Attributes
    ----------
    transcriber (TranscriberProtocol) : the transcriber object kept on the application state
    """

    transcriber: TranscriberProtocol

0 comments on commit 4f2b300

Please sign in to comment.