Skip to content

Commit

Permalink
add generate_from_phonemes()
Browse files Browse the repository at this point in the history
  • Loading branch information
eschmidbauer committed Jan 14, 2025
1 parent 5fc3696 commit 20952d4
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 64 deletions.
48 changes: 17 additions & 31 deletions api/src/routers/development.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
Expand All @@ -8,11 +6,8 @@
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
)
from ..structures.text_schemas import (GenerateFromPhonemesRequest,
PhonemeRequest, PhonemeResponse)

router = APIRouter(tags=["text processing"])

Expand Down Expand Up @@ -64,7 +59,7 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
request: GenerateFromPhonemesRequest,
tts_service: TTSService = Depends(get_tts_service),
tts_service: TTSService = Depends(get_tts_service)
) -> Response:
"""Generate audio directly from phonemes
Expand All @@ -76,55 +71,46 @@ async def generate_from_phonemes(
WAV audio bytes
"""
# Validate phonemes first
if not request.phonemes:
if not request.phonemes or len(request.phonemes) == 0:
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"}
)

# Validate voice exists
voice_path = tts_service._get_voice_path(request.voice)
if not voice_path:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
detail={"error": "Invalid request", "message": f"Voice not found: {request.voice}"}
)

try:
# Load voice
voicepack = tts_service._load_voice(voice_path)

# Convert phonemes to tokens
tokens = tokenize(request.phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens

# Generate audio directly from tokens
audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)

# Convert to WAV bytes
wav_bytes = AudioService.convert_audio(
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
)

trim_samples: int = 0
if request.trim_ms > 0:
trim_samples = int((request.trim_ms / 1000) * 24000)
pause_duration = np.zeros(int(24000 * request.pause_duration), dtype=np.float32)
audio: np.ndarray = TTSModel.generate_from_phonemes(phonemes=request.phonemes, voicepack=voicepack, speed=request.speed, trim_samples=trim_samples, pause_duration=pause_duration)
wav_bytes = AudioService.convert_audio(audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False)
return Response(
content=wav_bytes,
media_type="audio/wav",
headers={
"Content-Disposition": "attachment; filename=speech.wav",
"Cache-Control": "no-cache",
},
}
)

except ValueError as e:
logger.error(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400, detail={"error": "Invalid request", "message": str(e)}
status_code=400,
detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
logger.error(f"Error generating audio: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)}
status_code=500,
detail={"error": "Server error", "message": str(e)}
)
18 changes: 18 additions & 0 deletions api/src/services/tts_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
"""
pass

@classmethod
@abstractmethod
def generate_from_phonemes(
    cls,
    phonemes: List[str],
    voicepack: torch.Tensor,
    speed: float,
    trim_samples: int,
    pause_duration: np.ndarray,
) -> np.ndarray:
    """Generate audio from a list of phoneme strings.

    Abstract hook: concrete backends supply the implementation.

    Args:
        phonemes: Input phoneme strings, one entry per chunk
        voicepack: Voice tensor
        speed: Speed factor
        trim_samples: Trim samples from chunk
        pause_duration: Pause between chunks, as an array of samples

    Returns:
        np.ndarray: Generated audio samples
    """
    pass

@classmethod
@abstractmethod
def generate_from_text(
Expand Down
39 changes: 33 additions & 6 deletions api/src/services/tts_cpu.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import os
from typing import List

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from onnxruntime import (ExecutionMode, GraphOptimizationLevel,
InferenceSession, SessionOptions)

from ..core.config import settings
from .text_processing import phonemize, tokenize
Expand Down Expand Up @@ -108,6 +105,36 @@ def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
tokens = [0] + tokens + [0] # Add start/end tokens
return phonemes, tokens

@classmethod
def generate_from_phonemes(
    cls,
    phonemes: List[str],
    voicepack: torch.Tensor,
    speed: float,
    trim_samples: int,
    pause_duration: np.ndarray,
) -> np.ndarray:
    """Generate audio from a list of phoneme strings.

    Each entry of ``phonemes`` is tokenized and synthesized as its own
    chunk. ``pause_duration`` is inserted between consecutive chunks only;
    no pause is added after the final chunk.

    Args:
        phonemes: Input phoneme strings, one entry per chunk
        voicepack: Voice tensor
        speed: Speed factor
        trim_samples: Number of samples removed from both the start and the
            end of every chunk (0 disables trimming)
        pause_duration: Array of samples inserted between chunks

    Returns:
        np.ndarray: Generated audio samples

    Raises:
        RuntimeError: If the model has not been initialized
        ValueError: If ``phonemes`` is empty
    """
    if cls._instance is None:
        raise RuntimeError("CPU model not initialized")
    if not phonemes:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError("Phonemes cannot be empty")

    chunks: List[np.ndarray] = []
    last_index = len(phonemes) - 1
    for i, phoneme in enumerate(phonemes):
        tokens = [0] + tokenize(phoneme) + [0]  # Add start/end tokens
        audio = cls.generate_from_tokens(tokens, voicepack, speed)
        if trim_samples > 0:
            # NOTE: a trim of at least half the chunk length yields an
            # empty chunk.
            audio = audio[trim_samples:-trim_samples]
        chunks.append(audio)
        # Compare against the input length, not the growing output list,
        # so the pause only lands between chunks.
        if i < last_index:
            chunks.append(pause_duration)
    return np.concatenate(chunks) if len(chunks) > 1 else chunks[0]

@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
Expand Down
59 changes: 44 additions & 15 deletions api/src/services/tts_gpu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import time
from typing import List

import numpy as np
import torch
from ..builds.models import build_model
from loguru import logger

from ..builds.models import build_model
from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
Expand Down Expand Up @@ -40,7 +40,7 @@
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model with moderate memory management"""
device = ref_s.device

try:
# Initial tensor setup with proper device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
Expand Down Expand Up @@ -70,14 +70,14 @@ def forward(model, tokens, ref_s, speed):
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
pred_aln_trg[i, c_frame: c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)

# Matrix multiplications with selective cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d # Free large intermediate tensor

F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en # Free large intermediate tensor

Expand All @@ -89,9 +89,9 @@ def forward(model, tokens, ref_s, speed):
# Final decoding and transfer to CPU
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()

return result

finally:
# Let PyTorch handle most cleanup automatically
# Only explicitly free the largest tensors
Expand Down Expand Up @@ -165,6 +165,35 @@ def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
tokens = tokenize(phonemes)
return phonemes, tokens

@classmethod
def generate_from_phonemes(
    cls,
    phonemes: List[str],
    voicepack: torch.Tensor,
    speed: float,
    trim_samples: int,
    pause_duration: np.ndarray,
) -> np.ndarray:
    """Generate audio from a list of phoneme strings.

    Each entry of ``phonemes`` is tokenized and synthesized as its own
    chunk. ``pause_duration`` is inserted between consecutive chunks only;
    no pause is added after the final chunk.

    Args:
        phonemes: Input phoneme strings, one entry per chunk
        voicepack: Voice tensor
        speed: Speed factor
        trim_samples: Number of samples removed from both the start and the
            end of every chunk (0 disables trimming)
        pause_duration: Array of samples inserted between chunks

    Returns:
        np.ndarray: Generated audio samples

    Raises:
        RuntimeError: If the model has not been initialized
        ValueError: If ``phonemes`` is empty
    """
    if cls._instance is None:
        raise RuntimeError("GPU model not initialized")
    if not phonemes:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError("Phonemes cannot be empty")

    chunks: List[np.ndarray] = []
    last_index = len(phonemes) - 1
    for i, phoneme in enumerate(phonemes):
        # No start/end tokens here: forward() wraps the tokens with 0s.
        tokens = tokenize(phoneme)
        audio = cls.generate_from_tokens(tokens, voicepack, speed)
        if trim_samples > 0:
            # NOTE: a trim of at least half the chunk length yields an
            # empty chunk.
            audio = audio[trim_samples:-trim_samples]
        chunks.append(audio)
        # Compare against the input length, not the growing output list,
        # so the pause only lands between chunks.
        if i < last_index:
            chunks.append(pause_duration)
    return np.concatenate(chunks) if len(chunks) > 1 else chunks[0]

@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
Expand Down Expand Up @@ -210,7 +239,7 @@ def generate_from_tokens(

try:
device = cls._device

# Check memory pressure
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
Expand All @@ -222,15 +251,15 @@ def generate_from_tokens(
torch.cuda.empty_cache()
import gc
gc.collect()

# Get reference style with proper device placement
ref_s = voicepack[len(tokens)].clone().to(device)

# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)

return audio

except RuntimeError as e:
if "out of memory" in str(e):
# On OOM, do a full cleanup and retry
Expand All @@ -240,7 +269,7 @@ def generate_from_tokens(
torch.cuda.empty_cache()
import gc
gc.collect()

# Log memory stats after cleanup
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
Expand All @@ -249,13 +278,13 @@ def generate_from_tokens(
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
f"Reserved: {memory_reserved / 1e9:.2f}GB"
)

# Retry generation
ref_s = voicepack[len(tokens)].clone().to(device)
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
raise

finally:
# Only synchronize at the top level, no empty_cache
if torch.cuda.is_available():
Expand Down
9 changes: 5 additions & 4 deletions api/src/structures/text_schemas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from typing import List


class PhonemeRequest(BaseModel):
Expand All @@ -12,8 +13,8 @@ class PhonemeResponse(BaseModel):


class GenerateFromPhonemesRequest(BaseModel):
phonemes: str
phonemes: List[str]
voice: str = Field(..., description="Voice ID to use for generation")
speed: float = Field(
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
)
trim_ms: int = Field(default=0, ge=0, le=100000, description="Trim milliseconds of audio before adding pause")
pause_duration: float = Field(default=0.0, ge=0.0, le=60.0, description="Pause duration in seconds between sentences")
speed: float = Field(default=1.0, ge=0.1, le=5.0, description="Speed factor for generation")
8 changes: 4 additions & 4 deletions api/tests/test_text_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_generate_from_phonemes(
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": 1.0},
json={"phonemes": ["həlˈoʊ"], "voice": "af_bella", "trim_ms": 0, "pause_duration": 0.0, "speed": 1.0},
)

assert response.status_code == 200
Expand All @@ -80,7 +80,7 @@ async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_servi
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈoʊ", "voice": "invalid_voice", "speed": 1.0},
json={"phonemes": ["həlˈoʊ"], "voice": "invalid_voice", "trim_ms": 0, "pause_duration": 0.0, "speed": 1.0},
)

assert response.status_code == 400
Expand All @@ -101,7 +101,7 @@ async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):

response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": -1.0},
json={"phonemes": ["həlˈoʊ"], "voice": "af_bella", "trim_ms": 0, "pause_duration": 0.0, "speed": -1.0},
)

assert response.status_code == 422 # Validation error
Expand All @@ -115,7 +115,7 @@ async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_serv
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "", "voice": "af_bella", "speed": 1.0},
json={"phonemes": [], "voice": "af_bella", "trim_ms": 0, "pause_duration": 0.0, "speed": 1.0},
)

assert response.status_code == 400
Expand Down
Loading

0 comments on commit 20952d4

Please sign in to comment.