Skip to content

Realtime: fix interrupt and audio tracking #1220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(self) -> None:
self._ongoing_response: bool = False
self._current_audio_content_index: int | None = None
self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None
self._created_session: OpenAISessionObject | None = None

async def connect(self, options: RealtimeModelConfig) -> None:
"""Establish a connection to the model and keep it alive."""
Expand Down Expand Up @@ -298,10 +299,18 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
if not self._current_item_id or not self._audio_start_time:
return

await self._cancel_response()
automatic_response_cancellation_enabled = (
self._created_session
and self._created_session.turn_detection
and self._created_session.turn_detection.interrupt_response
)

if not automatic_response_cancellation_enabled:
await self._cancel_response()

elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:

if elapsed_time_ms > 0 and elapsed_time_ms <= self._audio_length_ms:
await self._emit_event(RealtimeModelAudioInterruptedEvent())
converted = _ConversionHelper.convert_interrupt(
self._current_item_id,
Expand Down Expand Up @@ -335,8 +344,16 @@ async def _handle_audio_delta(self, parsed: ResponseAudioDeltaEvent) -> None:
)

def _calculate_audio_length_ms(self, audio_bytes: bytes) -> float:
"""Calculate audio length in milliseconds for 24KHz PCM16LE format."""
return len(audio_bytes) / 24 / 2
audio_format = "pcm16"
if self._created_session and self._created_session.output_audio_format:
audio_format = self._created_session.output_audio_format

if audio_format.startswith("g711"):
# 8kHz * 1 byte per sample
return (len(audio_bytes) / 8000) * 1000 # Convert seconds to milliseconds
else:
# 24kHz * 2 bytes per sample
return (len(audio_bytes) / (24000 * 2)) * 1000 # Convert seconds to milliseconds

async def _handle_output_item(self, item: ConversationItem) -> None:
"""Handle response output item events (function calls and messages)."""
Expand Down Expand Up @@ -439,6 +456,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
self._ongoing_response = False
await self._emit_event(RealtimeModelTurnEndedEvent())
elif parsed.type == "session.created":
self._created_session = parsed.session
await self._send_tracing_config(self._tracing_config)
elif parsed.type == "error":
await self._emit_event(RealtimeModelErrorEvent(error=parsed.error))
Expand Down
27 changes: 20 additions & 7 deletions tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,30 @@ async def test_audio_timing_calculation_accuracy(self, model):
for event in audio_deltas:
await model._handle_ws_event(event)

# Should accumulate audio length: 8 bytes / 24 / 2 = ~0.167ms per byte
# Total: 8 bytes / 24 / 2 = 0.167ms
expected_length = 8 / 24 / 2
# Should accumulate audio length: 8 bytes / (24000 * 2) * 1000 = ~0.167ms total
expected_length = (8 / (24000 * 2)) * 1000
assert abs(model._audio_length_ms - expected_length) < 0.001

def test_calculate_audio_length_ms_pure_function(self, model):
"""Test the pure audio length calculation function."""
# Test various audio buffer sizes
assert model._calculate_audio_length_ms(b"test") == 4 / 24 / 2 # 4 bytes
# Test various audio buffer sizes for PCM16 (default)
assert model._calculate_audio_length_ms(b"test") == (4 / (24000 * 2)) * 1000 # 4 bytes
assert model._calculate_audio_length_ms(b"") == 0 # empty
assert model._calculate_audio_length_ms(b"a" * 48) == 1.0 # exactly 1ms worth
assert model._calculate_audio_length_ms(b"a" * 48000) == 1000.0 # exactly 1 second worth

def test_calculate_audio_length_ms_g711_format(self, model):
"""Test audio length calculation for G.711 format."""
from unittest.mock import Mock

# Mock session with g711 format
mock_session = Mock()
mock_session.output_audio_format = "g711_ulaw"
model._created_session = mock_session

# Test G.711: 8kHz * 1 byte per sample
assert model._calculate_audio_length_ms(b"test") == (4 / 8000) * 1000 # 4 bytes = 0.5ms
assert model._calculate_audio_length_ms(b"") == 0 # empty
assert model._calculate_audio_length_ms(b"a" * 8000) == 1000.0 # exactly 1 second worth

@pytest.mark.asyncio
async def test_handle_audio_delta_state_management(self, model):
Expand All @@ -376,4 +389,4 @@ async def test_handle_audio_delta_state_management(self, model):
assert model._current_audio_content_index == 5
assert model._current_item_id == "test_item"
assert model._audio_start_time == mock_now
assert model._audio_length_ms == 4 / 24 / 2 # 4 bytes
assert model._audio_length_ms == (4 / (24000 * 2)) * 1000 # 4 bytes, converted to ms
Loading