From 648d07b29ba8f70b14377f385a1ae4ca2fc27502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20S=C3=B8rensen?= Date: Fri, 13 Oct 2023 09:12:45 +0200 Subject: [PATCH] Retry voice handling when the endpoint returns 429 (#519) --- hass_nabucasa/voice.py | 26 +++++++++++++++++++++++-- tests/test_voice.py | 43 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/hass_nabucasa/voice.py b/hass_nabucasa/voice.py index 169adab9a..7de2ab213 100644 --- a/hass_nabucasa/voice.py +++ b/hass_nabucasa/voice.py @@ -1250,12 +1250,13 @@ async def process_stt( stream: AsyncIterable[bytes], content_type: str, language: str, + force_token_renewal: bool = False, ) -> STTResponse: """Stream Audio to Azure congnitive instance.""" if language not in STT_LANGUAGES: raise VoiceError(f"Language {language} not supported") - if not self._validate_token(): + if force_token_renewal or not self._validate_token(): await self._update_token() # Send request @@ -1271,6 +1272,15 @@ async def process_stt( expect100=True, chunked=True, ) as resp: + if resp.status == 429 and not force_token_renewal: + # By checking the force_token_renewal argument, we limit retries to 1. + _LOGGER.info("Retrying with new token") + return await self.process_stt( + stream=stream, + content_type=content_type, + language=language, + force_token_renewal=True, + ) if resp.status not in (200, 201): raise VoiceReturnError( f"Error processing {language} speech: {resp.status} {await resp.text()}" @@ -1290,6 +1300,7 @@ async def process_tts( output: AudioOutput, voice: str | None = None, gender: Gender | None = None, + force_token_renewal: bool = False, ) -> bytes: """Get Speech from text over Azure.""" if language not in TTS_VOICES: @@ -1306,7 +1317,7 @@ async def process_tts( if voice not in TTS_VOICES[language]: raise VoiceError(f"Unsupported voice {voice} for language {language}") - if not self._validate_token(): + if force_token_renewal or not self._validate_token(): await self._update_token() # SSML @@ -1339,6 +1350,17 @@ async def process_tts( }, data=ET.tostring(xml_body), ) as resp: + if resp.status == 429 and not force_token_renewal: + # By checking the force_token_renewal argument, we limit retries to 1. + _LOGGER.info("Retrying with new token") + return await self.process_tts( + text=text, + language=language, + output=output, + voice=voice, + gender=gender, + force_token_renewal=True, + ) if resp.status not in (200, 201): raise VoiceReturnError( f"Error receiving TTS with {language}/{voice}: {resp.status} {await resp.text()}" diff --git a/tests/test_voice.py b/tests/test_voice.py index 0597edb24..b3931fcee 100644 --- a/tests/test_voice.py +++ b/tests/test_voice.py @@ -140,3 +140,46 @@ async def test_process_tts_bad_voice(voice_api): voice="Not a US voice", output=voice.AudioOutput.MP3, ) + + +async def test_process_tss_429( + voice_api, mock_voice_connection_details, aioclient_mock, caplog +): + """Test handling of voice with 429.""" + aioclient_mock.post( + "tts-url", + status=429, + ) + + with pytest.raises(voice.VoiceError): + await voice_api.process_tts( + text="Text for Saying", + language="en-US", + gender=voice.Gender.FEMALE, + output=voice.AudioOutput.MP3, + ) + + assert len(aioclient_mock.mock_calls) == 4 + + assert "Retrying with new token" in caplog.text + + +async def test_process_stt_429( + voice_api, mock_voice_connection_details, aioclient_mock, caplog +): + """Test handling of voice with 429.""" + aioclient_mock.post( + "stt-url", + status=429, + ) + + with pytest.raises(voice.VoiceError): + await voice_api.process_stt( + stream=b"feet", + content_type="video=test", + language="en-US", + ) + + assert len(aioclient_mock.mock_calls) == 4 + + assert "Retrying with new token" in caplog.text