Skip to content

Commit

Permalink
Retry voice handling when the endpoint returns 429 (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludeeus authored Oct 13, 2023
1 parent 4b59378 commit 648d07b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
26 changes: 24 additions & 2 deletions hass_nabucasa/voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,12 +1250,13 @@ async def process_stt(
stream: AsyncIterable[bytes],
content_type: str,
language: str,
force_token_renewal: bool = False,
) -> STTResponse:
"""Stream Audio to Azure congnitive instance."""
if language not in STT_LANGUAGES:
raise VoiceError(f"Language {language} not supported")

if not self._validate_token():
if force_token_renewal or not self._validate_token():
await self._update_token()

# Send request
Expand All @@ -1271,6 +1272,15 @@ async def process_stt(
expect100=True,
chunked=True,
) as resp:
if resp.status == 429 and not force_token_renewal:
# By checking the force_token_renewal argument, we limit retries to 1.
_LOGGER.info("Retrying with new token")
return await self.process_stt(
stream=stream,
content_type=content_type,
language=language,
force_token_renewal=True,
)
if resp.status not in (200, 201):
raise VoiceReturnError(
f"Error processing {language} speech: {resp.status} {await resp.text()}"
Expand All @@ -1290,6 +1300,7 @@ async def process_tts(
output: AudioOutput,
voice: str | None = None,
gender: Gender | None = None,
force_token_renewal: bool = False,
) -> bytes:
"""Get Speech from text over Azure."""
if language not in TTS_VOICES:
Expand All @@ -1306,7 +1317,7 @@ async def process_tts(
if voice not in TTS_VOICES[language]:
raise VoiceError(f"Unsupported voice {voice} for language {language}")

if not self._validate_token():
if force_token_renewal or not self._validate_token():
await self._update_token()

# SSML
Expand Down Expand Up @@ -1339,6 +1350,17 @@ async def process_tts(
},
data=ET.tostring(xml_body),
) as resp:
if resp.status == 429 and not force_token_renewal:
# By checking the force_token_renewal argument, we limit retries to 1.
_LOGGER.info("Retrying with new token")
return await self.process_tts(
text=text,
language=language,
output=output,
voice=voice,
gender=gender,
force_token_renewal=True,
)
if resp.status not in (200, 201):
raise VoiceReturnError(
f"Error receiving TTS with {language}/{voice}: {resp.status} {await resp.text()}"
Expand Down
43 changes: 43 additions & 0 deletions tests/test_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,46 @@ async def test_process_tts_bad_voice(voice_api):
voice="Not a US voice",
output=voice.AudioOutput.MP3,
)


async def test_process_tss_429(
voice_api, mock_voice_connection_details, aioclient_mock, caplog
):
"""Test handling of voice with 429."""
aioclient_mock.post(
"tts-url",
status=429,
)

with pytest.raises(voice.VoiceError):
await voice_api.process_tts(
text="Text for Saying",
language="en-US",
gender=voice.Gender.FEMALE,
output=voice.AudioOutput.MP3,
)

assert len(aioclient_mock.mock_calls) == 4

assert "Retrying with new token" in caplog.text


async def test_process_stt_429(
voice_api, mock_voice_connection_details, aioclient_mock, caplog
):
"""Test handling of voice with 429."""
aioclient_mock.post(
"stt-url",
status=429,
)

with pytest.raises(voice.VoiceError):
await voice_api.process_stt(
stream=b"feet",
content_type="video=test",
language="en-US",
)

assert len(aioclient_mock.mock_calls) == 4

assert "Retrying with new token" in caplog.text

0 comments on commit 648d07b

Please sign in to comment.