diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index c137306a..36b13554 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -250,6 +250,7 @@ def read(self) -> Optional[OAuthToken]:
         self.telemetry_enabled = (
             self.client_telemetry_enabled and self.server_telemetry_enabled
         )
+        self.telemetry_batch_size = kwargs.get("telemetry_batch_size")
 
         user_agent_entry = kwargs.get("user_agent_entry")
         if user_agent_entry is None:
@@ -311,6 +312,7 @@ def read(self) -> Optional[OAuthToken]:
             session_id_hex=self.get_session_id_hex(),
             auth_provider=auth_provider,
             host_url=self.host,
+            batch_size=self.telemetry_batch_size,
         )
 
         self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 10aa04ef..a90a4bd6 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -116,6 +116,10 @@ def export_failure_log(self, error_name, error_message):
     def close(self):
         raise NotImplementedError("Subclasses must implement close")
 
+    @abstractmethod
+    def flush(self):
+        raise NotImplementedError("Subclasses must implement flush")
+
 
 class NoopTelemetryClient(BaseTelemetryClient):
     """
@@ -139,6 +143,9 @@ def export_failure_log(self, error_name, error_message):
     def close(self):
         pass
 
+    def flush(self):
+        pass
+
 
 class TelemetryClient(BaseTelemetryClient):
     """
@@ -150,6 +157,8 @@ class TelemetryClient(BaseTelemetryClient):
     TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
     TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
 
+    DEFAULT_BATCH_SIZE = 100
+
     def __init__(
         self,
         telemetry_enabled,
@@ -157,10 +166,13 @@ def __init__(
         auth_provider,
         host_url,
         executor,
+        batch_size=None,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
-        self._batch_size = 10  # TODO: Decide on batch size
+        self._batch_size = (
+            batch_size if batch_size is not None else self.DEFAULT_BATCH_SIZE
+        )
         self._session_id_hex = session_id_hex
         self._auth_provider = auth_provider
         self._user_agent = None
@@ -172,6 +184,7 @@ def __init__(
 
     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
+        logger.debug("Exporting event for connection %s", self._session_id_hex)
         with self._lock:
             self._events_batch.append(event)
 
@@ -179,10 +192,11 @@ def _export_event(self, event):
             logger.debug(
                 "Batch size limit reached (%s), flushing events", self._batch_size
             )
-            self._flush()
+            self.flush()
 
-    def _flush(self):
+    def flush(self):
         """Flush the current batch of events to the server"""
+
         with self._lock:
             events_to_flush = self._events_batch.copy()
             self._events_batch = []
@@ -302,13 +316,13 @@ def export_failure_log(self, error_name, error_message):
     def close(self):
         """Flush remaining events before closing"""
         logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
-        self._flush()
+        self.flush()
 
 
 class TelemetryClientFactory:
     """
     Static factory class for creating and managing telemetry clients.
-    It uses a thread pool to handle asynchronous operations.
+    It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
     """
 
     _clients: Dict[
@@ -321,6 +335,11 @@ class TelemetryClientFactory:
     _original_excepthook = None
     _excepthook_installed = False
 
+    # Shared flush thread for all clients
+    _flush_thread = None
+    _flush_event = threading.Event()
+    _flush_interval_seconds = 90
+
     @classmethod
     def _initialize(cls):
         """Initialize the factory if not already initialized"""
@@ -331,11 +350,42 @@ def _initialize(cls):
             max_workers=10
         )  # Thread pool for async operations TODO: Decide on max workers
         cls._install_exception_hook()
+        cls._start_flush_thread()
         cls._initialized = True
         logger.debug(
-            "TelemetryClientFactory initialized with thread pool (max_workers=10)"
+            "TelemetryClientFactory initialized with thread pool (max_workers=10) and shared flush thread"
         )
 
+    @classmethod
+    def _start_flush_thread(cls):
+        """Start the shared background thread for periodic flushing of all clients"""
+        cls._flush_event.clear()
+        cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True)
+        cls._flush_thread.start()
+
+    @classmethod
+    def _flush_worker(cls):
+        """Background worker thread for periodic flushing of all clients"""
+        while not cls._flush_event.wait(cls._flush_interval_seconds):
+            logger.debug("Performing periodic flush for all telemetry clients")
+
+            with cls._lock:
+                clients_to_flush = list(cls._clients.values())
+
+            for client in clients_to_flush:
+                try:
+                    client.flush()
+                except Exception as e:
+                    logger.debug("Failed to flush telemetry client: %s", e)
+
+    @classmethod
+    def _stop_flush_thread(cls):
+        """Stop the shared background flush thread"""
+        if cls._flush_thread is not None:
+            cls._flush_event.set()
+            cls._flush_thread.join(timeout=1.0)
+            cls._flush_thread = None
+
     @classmethod
     def _install_exception_hook(cls):
         """Install global exception handler for unhandled exceptions"""
@@ -364,6 +414,7 @@ def initialize_telemetry_client(
         session_id_hex,
         auth_provider,
         host_url,
+        batch_size=None,
     ):
         """Initialize a telemetry client for a specific connection if telemetry is enabled"""
         try:
@@ -385,6 +436,7 @@ def initialize_telemetry_client(
                     auth_provider=auth_provider,
                     host_url=host_url,
                     executor=TelemetryClientFactory._executor,
+                    batch_size=batch_size,
                 )
             else:
                 TelemetryClientFactory._clients[
@@ -426,11 +478,12 @@ def close(session_id_hex):
                 )
                 telemetry_client.close()
 
-            # Shutdown executor if no more clients
+            # Shutdown executor and flush thread if no more clients
             if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
                 logger.debug(
-                    "No more telemetry clients, shutting down thread pool executor"
+                    "No more telemetry clients, shutting down thread pool executor and flush thread"
                 )
+                TelemetryClientFactory._stop_flush_thread()
                 TelemetryClientFactory._executor.shutdown(wait=True)
                 TelemetryClientFactory._executor = None
                 TelemetryClientFactory._initialized = False
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index 699480bb..418d3927 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -181,22 +181,22 @@ def test_export_failure_log(
         client._export_event.assert_called_once_with(mock_frontend_log.return_value)
 
-    def test_export_event(self, telemetry_client_setup):
-        """Test exporting an event."""
+    def test_batch_size_flush(self, telemetry_client_setup):
+        """Test batch size flush."""
         client = telemetry_client_setup["client"]
-        client._flush = MagicMock()
+        client.flush = MagicMock()
 
-        for i in range(5):
+        batch_size = client._batch_size
+
+        for i in range(batch_size - 1):
             client._export_event(f"event-{i}")
-        client._flush.assert_not_called()
-        assert len(client._events_batch) == 5
+        client.flush.assert_not_called()
+        assert len(client._events_batch) == batch_size - 1
 
-        for i in range(5, 10):
-            client._export_event(f"event-{i}")
+        client._export_event(f"event-{batch_size - 1}")
 
-        client._flush.assert_called_once()
-        assert len(client._events_batch) == 10
+        client.flush.assert_called_once()
 
     @patch("requests.post")
     def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
@@ -251,7 +251,7 @@ def test_flush(self, telemetry_client_setup):
         client._events_batch = ["event1", "event2"]
         client._send_telemetry = MagicMock()
 
-        client._flush()
+        client.flush()
 
         client._send_telemetry.assert_called_once_with(["event1", "event2"])
         assert client._events_batch == []
@@ -259,11 +259,11 @@ def test_close(self, telemetry_client_setup):
     def test_close(self, telemetry_client_setup):
         """Test closing the client."""
         client = telemetry_client_setup["client"]
-        client._flush = MagicMock()
+        client.flush = MagicMock()
 
         client.close()
 
-        client._flush.assert_called_once()
+        client.flush.assert_called_once()
 
     @patch("requests.post")
     def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup):