Skip to content

Commit

Permalink
Merge pull request #23 from ultrasev/concurrent-recording
Browse files Browse the repository at this point in the history
fix(src.local_deploy): support concurrent recording
  • Loading branch information
ultrasev authored Apr 18, 2024
2 parents 0338325 + 67cbbd6 commit e8bf539
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 28 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/docker-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Build and Publish Docker image

on:
push:
branches:
- master

jobs:
build-and-push:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

- name: Log in to GitHub Container Registry
uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CR_PAT }}

- name: Build and push Docker image
uses: docker/build-push-action@v2
with:
context: .
file: ./Dockerfile
push: true
tags: ghcr.io/${{ github.repository_owner }}/whisper:latest
11 changes: 11 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM python:3.8-slim
WORKDIR /app/
COPY requirements.txt /app/

RUN apt update && apt install -y libpq-dev gcc portaudio19-dev
RUN pip3 install -r requirements.txt
RUN pip3 install uvicorn fastapi pydantic python-multipart loguru==0.7.0

COPY ./src/docker/whisper.py /app/

CMD ["uvicorn", "whisper:app", "--host", "0.0.0.0", "--port", "8000"]
103 changes: 103 additions & 0 deletions src/docker/whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python
import asyncio
import os
import typing
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import av
from fastapi import FastAPI, File, HTTPException, UploadFile
from faster_whisper import WhisperModel
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

model_size = os.getenv('MODEL', 'base')


class ValidateFileTypeMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method.lower() == "post":
try:
response = await call_next(request)
return response
except av.error.InvalidDataError:
return JSONResponse(status_code=400,
content={"message": "Invalid file type"})
except Exception as e:
return JSONResponse(status_code=500,
content={"message": str(e)})


app = FastAPI()
app.add_middleware(ValidateFileTypeMiddleware)


async def asyncformer(sync_func: typing.Callable, *args, **kwargs):
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, sync_func, *args, **kwargs)


class Transcriber:
_instance = None

def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(Transcriber, cls).__new__(cls)
# Put any initialization here.
return cls._instance

def __init__(
self,
model_size: str,
device: str = "auto",
compute_type: str = "default",
prompt: str = '实时/低延迟语音转写服务,林黛玉、倒拔、杨柳树、鲁迅、周树人、关键词、转写正确') -> None:
""" FasterWhisper 语音转写
Args:
model_size (str): 模型大小,可选项为 "tiny", "base", "small", "medium", "large" 。
更多信息参考:https://github.com/openai/whisper
device (str, optional): 模型运行设备。
compute_type (str, optional): 计算类型。默认为"default"。
prompt (str, optional): 初始提示。如果需要转写简体中文,可以使用简体中文提示。
"""
super().__init__()
self.model_size = model_size
self.device = device
self.compute_type = compute_type
self.prompt = prompt

def __enter__(self) -> 'Transcriber':
self._model = WhisperModel(self.model_size,
device=self.device,
compute_type=self.compute_type)
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

async def __call__(self, audio: bytes) -> typing.AsyncGenerator[str, None]:
def _process():
return self._model.transcribe(BytesIO(audio),
initial_prompt=self.prompt,
vad_filter=True)

segments, info = await asyncformer(_process)
for segment in segments:
t = segment.text
if self.prompt in t.strip():
continue
if t.strip().replace('.', ''):
logger.info(t)
yield t


@app.post("/v1/audio/transcriptions")
async def _transcribe(file: UploadFile = File(...)):
with Transcriber(model_size) as stt:
audio = await file.read()
text = ','.join([seg async for seg in stt(audio)])
return {"text": text}
97 changes: 69 additions & 28 deletions src/local_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,36 @@
python3 local_deploy.py
"""

from faster_whisper import WhisperModel
from io import BytesIO
import typing
import io
import collections
import io
import logging
import queue
import threading
import typing
import wave
from io import BytesIO

import codefast as cf
import pyaudio
import webrtcvad
import logging
from faster_whisper import WhisperModel

logging.basicConfig(
level=logging.INFO,
format='%(name)s - %(levelname)s - %(message)s')
logging.basicConfig(level=logging.INFO,
format='%(name)s - %(levelname)s - %(message)s')


class Transcriber(object):
def __init__(self,
model_size: str,
device: str = "auto",
compute_type: str = "default",
prompt: str = '实时/低延迟语音转写服务,林黛玉、倒拔、杨柳树、鲁迅、周树人、关键词、转写正确'
) -> None:
class Queues:
audio = queue.Queue()
text = queue.Queue()


class Transcriber(threading.Thread):
def __init__(
self,
model_size: str,
device: str = "auto",
compute_type: str = "default",
prompt: str = '实时/低延迟语音转写服务,林黛玉、倒拔、杨柳树、鲁迅、周树人、关键词、转写正确') -> None:
""" FasterWhisper 语音转写
Args:
Expand All @@ -41,46 +48,58 @@ def __init__(self,
compute_type (str, optional): 计算类型。默认为"default"。
prompt (str, optional): 初始提示。如果需要转写简体中文,可以使用简体中文提示。
"""

super().__init__()
self.model_size = model_size
self.device = device
self.compute_type = compute_type
self.prompt = prompt

def __enter__(self) -> 'Transcriber':
self._model = WhisperModel(self.model_size,
device=self.device,
compute_type=self.compute_type)
device=self.device,
compute_type=self.compute_type)
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
pass

def __call__(self, audio: bytes) -> typing.Generator[str, None, None]:
segments, info = self._model.transcribe(BytesIO(audio),
initial_prompt=self.prompt)
if info.language != "zh":
return {"error": "transcribe Chinese only"}
initial_prompt=self.prompt,
vad_filter=True)
# if info.language != "zh":
# return {"error": "transcribe Chinese only"}
for segment in segments:
t = segment.text
if self.prompt in t.strip():
continue
if t.strip().replace('.', ''):
yield t

def run(self):
while True:
audio = Queues.audio.get()
text = ''
for seg in self(audio):
logging.info(cf.fp.cyan(seg))
text += seg
Queues.text.put(text)


class AudioRecorder(object):
class AudioRecorder(threading.Thread):
""" Audio recorder.
Args:
channels (int, 可选): 通道数,默认为1(单声道)。
rate (int, 可选): 采样率,默认为16000 Hz。
chunk (int, 可选): 缓冲区中的帧数,默认为256。
frame_duration (int, 可选): 每帧的持续时间(单位:毫秒),默认为30。
"""

def __init__(self,
channels: int = 1,
sample_rate: int = 16000,
chunk: int = 256,
frame_duration: int = 30) -> None:
super().__init__()
self.sample_rate = sample_rate
self.channels = channels
self.chunk = chunk
Expand Down Expand Up @@ -116,7 +135,7 @@ def __bytes__(self) -> bytes:
self.__frames.clear()
return buf.getvalue()

def __iter__(self):
def run(self):
""" Record audio until silence is detected.
"""
MAXLEN = 30
Expand All @@ -139,16 +158,38 @@ def __iter__(self):
if num_unvoiced > ratio * watcher.maxlen:
logging.info("stop recording...")
triggered = False
yield bytes(self)
Queues.audio.put(bytes(self))
logging.info("audio task number: {}".format(
Queues.audio.qsize()))


class Chat(threading.Thread):
def __init__(self, prompt: str) -> None:
super().__init__()
self.prompt = prompt

def run(self):
prompt = "Hey! I'm currently working on my English speaking skills and I was hoping you could help me out. If you notice any mistakes in my expressions or if something I say doesn't sound quite right, could you please correct me? And if everything's fine, just carry on with a normal conversation. I'd really appreciate it if you could reply in a conversational, spoken English style. This way, it feels more like a natural chat. Thanks a lot for your help!"
while True:
text = Queues.text.get()
if text:
import os
os.system('chat "{}"'.format(prompt + text))
prompt = ""


def main():
try:
with AudioRecorder(channels=1, sample_rate=16000) as recorder:
with Transcriber(model_size="base") as transcriber:
for audio in recorder:
for seg in transcriber(audio):
logging.info(seg)
recorder.start()
transcriber.start()
# chat = Chat("")
# chat.start()

recorder.join()
transcriber.join()

except KeyboardInterrupt:
print("KeyboardInterrupt: terminating...")
except Exception as e:
Expand Down

0 comments on commit e8bf539

Please sign in to comment.