Skip to content

Make telemetry batch size configurable and add time-based flush #622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: telemetry
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def read(self) -> Optional[OAuthToken]:
self.telemetry_enabled = (
self.client_telemetry_enabled and self.server_telemetry_enabled
)
self.telemetry_batch_size = kwargs.get("telemetry_batch_size")

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
Expand Down Expand Up @@ -311,6 +312,7 @@ def read(self) -> Optional[OAuthToken]:
session_id_hex=self.get_session_id_hex(),
auth_provider=auth_provider,
host_url=self.host,
batch_size=self.telemetry_batch_size,
)

self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
Expand Down
69 changes: 61 additions & 8 deletions src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def export_failure_log(self, error_name, error_message):
def close(self):
raise NotImplementedError("Subclasses must implement close")

@abstractmethod
def flush(self):
raise NotImplementedError("Subclasses must implement flush")


class NoopTelemetryClient(BaseTelemetryClient):
"""
Expand All @@ -139,6 +143,9 @@ def export_failure_log(self, error_name, error_message):
def close(self):
pass

def flush(self):
pass


class TelemetryClient(BaseTelemetryClient):
"""
Expand All @@ -150,17 +157,22 @@ class TelemetryClient(BaseTelemetryClient):
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"

DEFAULT_BATCH_SIZE = 100

def __init__(
self,
telemetry_enabled,
session_id_hex,
auth_provider,
host_url,
executor,
batch_size=None,
):
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
self._telemetry_enabled = telemetry_enabled
self._batch_size = 10 # TODO: Decide on batch size
self._batch_size = (
batch_size if batch_size is not None else self.DEFAULT_BATCH_SIZE
)
self._session_id_hex = session_id_hex
self._auth_provider = auth_provider
self._user_agent = None
Expand All @@ -172,17 +184,19 @@ def __init__(

def _export_event(self, event):
"""Add an event to the batch queue and flush if batch is full"""

logger.debug("Exporting event for connection %s", self._session_id_hex)
with self._lock:
self._events_batch.append(event)
if len(self._events_batch) >= self._batch_size:
logger.debug(
"Batch size limit reached (%s), flushing events", self._batch_size
)
self._flush()
self.flush()

def _flush(self):
def flush(self):
"""Flush the current batch of events to the server"""

with self._lock:
events_to_flush = self._events_batch.copy()
self._events_batch = []
Expand Down Expand Up @@ -302,13 +316,13 @@ def export_failure_log(self, error_name, error_message):
def close(self):
"""Flush remaining events before closing"""
logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)
self._flush()
self.flush()


class TelemetryClientFactory:
"""
Static factory class for creating and managing telemetry clients.
It uses a thread pool to handle asynchronous operations.
It uses a thread pool to handle asynchronous operations and a single flush thread for all clients.
"""

_clients: Dict[
Expand All @@ -321,6 +335,11 @@ class TelemetryClientFactory:
_original_excepthook = None
_excepthook_installed = False

# Shared flush thread for all clients
_flush_thread = None
_flush_event = threading.Event()
_flush_interval_seconds = 90

@classmethod
def _initialize(cls):
"""Initialize the factory if not already initialized"""
Expand All @@ -331,11 +350,42 @@ def _initialize(cls):
max_workers=10
) # Thread pool for async operations TODO: Decide on max workers
cls._install_exception_hook()
cls._start_flush_thread()
cls._initialized = True
logger.debug(
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
"TelemetryClientFactory initialized with thread pool (max_workers=10) and shared flush thread"
)

@classmethod
def _start_flush_thread(cls):
"""Start the shared background thread for periodic flushing of all clients"""
cls._flush_event.clear()
cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True)
cls._flush_thread.start()

@classmethod
def _flush_worker(cls):
"""Background worker thread for periodic flushing of all clients"""
while not cls._flush_event.wait(cls._flush_interval_seconds):
logger.debug("Performing periodic flush for all telemetry clients")

with cls._lock:
clients_to_flush = list(cls._clients.values())

for client in clients_to_flush:
try:
client.flush()
except Exception as e:
logger.debug("Failed to flush telemetry client: %s", e)

@classmethod
def _stop_flush_thread(cls):
"""Stop the shared background flush thread"""
if cls._flush_thread is not None:
cls._flush_event.set()
cls._flush_thread.join(timeout=1.0)
cls._flush_thread = None

@classmethod
def _install_exception_hook(cls):
"""Install global exception handler for unhandled exceptions"""
Expand Down Expand Up @@ -364,6 +414,7 @@ def initialize_telemetry_client(
session_id_hex,
auth_provider,
host_url,
batch_size=None,
):
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
try:
Expand All @@ -385,6 +436,7 @@ def initialize_telemetry_client(
auth_provider=auth_provider,
host_url=host_url,
executor=TelemetryClientFactory._executor,
batch_size=batch_size,
)
else:
TelemetryClientFactory._clients[
Expand Down Expand Up @@ -426,11 +478,12 @@ def close(session_id_hex):
)
telemetry_client.close()

# Shutdown executor if no more clients
# Shutdown executor and flush thread if no more clients
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
logger.debug(
"No more telemetry clients, shutting down thread pool executor"
"No more telemetry clients, shutting down thread pool executor and flush thread"
)
TelemetryClientFactory._stop_flush_thread()
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False
26 changes: 13 additions & 13 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,22 +181,22 @@ def test_export_failure_log(

client._export_event.assert_called_once_with(mock_frontend_log.return_value)

def test_export_event(self, telemetry_client_setup):
"""Test exporting an event."""
def test_batch_size_flush(self, telemetry_client_setup):
"""Test batch size flush."""
client = telemetry_client_setup["client"]
client._flush = MagicMock()
client.flush = MagicMock()

for i in range(5):
batch_size = client._batch_size

for i in range(batch_size - 1):
client._export_event(f"event-{i}")

client._flush.assert_not_called()
assert len(client._events_batch) == 5
client.flush.assert_not_called()
assert len(client._events_batch) == batch_size - 1

for i in range(5, 10):
client._export_event(f"event-{i}")
client._export_event(f"event-{batch_size - 1}")

client._flush.assert_called_once()
assert len(client._events_batch) == 10
client.flush.assert_called_once()

@patch("requests.post")
def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
Expand Down Expand Up @@ -251,19 +251,19 @@ def test_flush(self, telemetry_client_setup):
client._events_batch = ["event1", "event2"]
client._send_telemetry = MagicMock()

client._flush()
client.flush()

client._send_telemetry.assert_called_once_with(["event1", "event2"])
assert client._events_batch == []

def test_close(self, telemetry_client_setup):
"""Test closing the client."""
client = telemetry_client_setup["client"]
client._flush = MagicMock()
client.flush = MagicMock()

client.close()

client._flush.assert_called_once()
client.flush.assert_called_once()

@patch("requests.post")
def test_telemetry_request_callback_success(self, mock_post, telemetry_client_setup):
Expand Down
Loading