diff --git a/packages/opal-client/opal_client/callbacks/reporter.py b/packages/opal-client/opal_client/callbacks/reporter.py index 7b636c83a..12adb779b 100644 --- a/packages/opal-client/opal_client/callbacks/reporter.py +++ b/packages/opal-client/opal_client/callbacks/reporter.py @@ -85,5 +85,5 @@ async def report_update_results( status=result.status, error=error_content, ) - except: - logger.exception("Failed to execute report_update_results") + except Exception as e: + logger.exception(f"Failed to execute report_update_results: {e}") diff --git a/packages/opal-client/opal_client/data/fetcher.py b/packages/opal-client/opal_client/data/fetcher.py index b96c27d39..edb38944b 100644 --- a/packages/opal-client/opal_client/data/fetcher.py +++ b/packages/opal-client/opal_client/data/fetcher.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from opal_client.config import opal_client_config from opal_client.policy_store.base_policy_store_client import JsonableValue @@ -58,8 +58,8 @@ async def stop(self): await self._engine.terminate_workers() async def handle_url( - self, url: str, config: FetcherConfig, data: Optional[JsonableValue] - ): + self, url: str, config: dict, data: Optional[JsonableValue] + ) -> Optional[JsonableValue]: """Helper function wrapping self._engine.handle_url.""" if data is not None: logger.info("Data provided inline for url: {url}", url=url) @@ -107,7 +107,7 @@ async def handle_urls( results_with_url_and_config = [ (url, config, result) for (url, config, data), result in zip(urls, results) - if result is not None + if result is not None # FIXME ignores None results ] # return results diff --git a/packages/opal-client/opal_client/data/updater.py b/packages/opal-client/opal_client/data/updater.py index c99ca3884..1c9dd9306 100644 --- a/packages/opal-client/opal_client/data/updater.py +++ b/packages/opal-client/opal_client/data/updater.py @@ -1,10 +1,9 @@ import asyncio import hashlib -import itertools import json import uuid from functools import partial -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import aiohttp from aiohttp.client import ClientError, ClientSession @@ -20,13 +19,13 @@ from opal_client.policy_store.base_policy_store_client import ( BasePolicyStoreClient, JsonableValue, + PolicyStoreTransactionContextManager, ) from opal_client.policy_store.policy_store_client_factory import ( DEFAULT_POLICY_STORE_GETTER, ) -from opal_common.async_utils import TakeANumberQueue, TasksPool, repeated_call +from opal_common.async_utils import TasksPool, repeated_call from opal_common.config import opal_common_config -from opal_common.fetcher.events import FetcherConfig from opal_common.http_utils import is_http_error_response from opal_common.schemas.data import ( DataEntryReport, @@ -37,11 +36,26 @@ ) from opal_common.schemas.store import TransactionType from opal_common.security.sslcontext import get_custom_ssl_context +from opal_common.synchronization.hierarchical_lock import HierarchicalLock from opal_common.utils import get_authorization_header from pydantic.json import pydantic_encoder class DataUpdater: + """The DataUpdater is responsible for synchronizing data sources with the + policy store (e.g. OPA). It listens to Pub/Sub topics for data updates, + fetches the updated data, and writes it into the policy store. The updater + also supports a "base fetch" flow on startup or reconnection, pulling data + from a configuration endpoint. 
+ + Key Responsibilities: + - Subscribe to data update topics. + - Fetch new or changed data (using Fetchers, e.g. HTTP) + - Write updates to the policy store, ensuring concurrency safety. + - Periodically poll data sources (if configured). + - Report or callback the outcome of data updates (if configured). + """ + def __init__( self, token: str = None, @@ -50,7 +64,7 @@ def __init__( fetch_on_connect: bool = True, data_topics: List[str] = None, policy_store: BasePolicyStoreClient = None, - should_send_reports=None, + should_send_reports: Optional[bool] = None, data_fetcher: Optional[DataFetcher] = None, callbacks_register: Optional[CallbacksRegister] = None, opal_client_id: str = None, @@ -58,18 +72,23 @@ def __init__( on_connect: List[PubSubOnConnectCallback] = None, on_disconnect: List[OnDisconnectCallback] = None, ): - """Keeps policy-stores (e.g. OPA) up to date with relevant data Obtains - data configuration on startup from OPAL-server Uses Pub/Sub to - subscribe to data update events, and fetches (using FetchingEngine) - data from sources. + """Initializes the DataUpdater with the necessary configuration and + clients. Args: token (str, optional): Auth token to include in connections to OPAL server. Defaults to CLIENT_TOKEN. pubsub_url (str, optional): URL for Pub/Sub updates for data. Defaults to OPAL_SERVER_PUBSUB_URL. data_sources_config_url (str, optional): URL to retrieve base data configuration. Defaults to DEFAULT_DATA_SOURCES_CONFIG_URL. - fetch_on_connect (bool, optional): Should the update fetch basic data immediately upon connection/reconnection. Defaults to True. - data_topics (List[str], optional): Topics of data to fetch and subscribe to. Defaults to DATA_TOPICS. - policy_store (BasePolicyStoreClient, optional): Policy store client to use to store data. Defaults to DEFAULT_POLICY_STORE. + fetch_on_connect (bool, optional): Whether to fetch all data immediately upon connection. + data_topics (List[str], optional): Pub/Sub topics to subscribe to. Defaults to DATA_TOPICS. + policy_store (BasePolicyStoreClient, optional): The client used to store data. Defaults to DEFAULT_POLICY_STORE. + should_send_reports (bool, optional): Whether to report on data updates to callbacks. Defaults to SHOULD_REPORT_ON_DATA_UPDATES. + data_fetcher (DataFetcher, optional): Custom data fetching engine. + callbacks_register (CallbacksRegister, optional): Manages user-defined callbacks. + opal_client_id (str, optional): A unique identifier for this OPAL client. + shard_id (str, optional): A partition/shard identifier. Translates to an HTTP header. + on_connect (List[PubSubOnConnectCallback], optional): Extra on-connect callbacks. + on_disconnect (List[OnDisconnectCallback], optional): Extra on-disconnect callbacks. 
""" # Defaults token: str = token or opal_client_config.CLIENT_TOKEN @@ -88,13 +107,14 @@ def __init__( data_sources_config_url = ( f"{opal_client_config.SERVER_URL}/scopes/{self._scope_id}/data" ) + # Namespacing the data topics for the specific scope self._data_topics = [ f"{self._scope_id}:data:{topic}" for topic in self._data_topics ] - # Should the client use the default data source to fetch on connect + # Should the client fetch data when it first connects (or reconnects) self._fetch_on_connect = fetch_on_connect - # The policy store we'll save data updates into + # Policy store client self._policy_store = policy_store or DEFAULT_POLICY_STORE_GETTER() self._should_send_reports = ( @@ -102,21 +122,23 @@ def __init__( if should_send_reports is not None else opal_client_config.SHOULD_REPORT_ON_DATA_UPDATES ) - # The pub/sub client for data updates - self._client = None - # The task running the Pub/Sub subscribing client - self._subscriber_task = None - # Data fetcher + + # Will be set once we subscribe and connect + self._client: Optional[PubSubClient] = None + self._subscriber_task: Optional[asyncio.Task] = None + + # DataFetcher is a helper that can handle different data sources (HTTP, local, etc.) self._data_fetcher = data_fetcher or DataFetcher() self._callbacks_register = callbacks_register or CallbacksRegister() - self._callbacks_reporter = CallbacksReporter( - self._callbacks_register, - ) + self._callbacks_reporter = CallbacksReporter(self._callbacks_register) + self._token = token self._shard_id = shard_id self._server_url = pubsub_url self._data_sources_config_url = data_sources_config_url self._opal_client_id = opal_client_id + + # Prepare any extra headers (token, shard id, etc.) self._extra_headers = [] if self._token is not None: self._extra_headers.append(get_authorization_header(self._token)) @@ -124,17 +146,23 @@ def __init__( self._extra_headers.append(("X-Shard-ID", self._shard_id)) if len(self._extra_headers) == 0: self._extra_headers = None + self._stopping = False - # custom SSL context (for self-signed certificates) self._custom_ssl_context = get_custom_ssl_context() self._ssl_context_kwargs = ( - {"ssl": self._custom_ssl_context} - if self._custom_ssl_context is not None - else {} + {"ssl": self._custom_ssl_context} if self._custom_ssl_context else {} ) - self._updates_storing_queue = TakeANumberQueue(logger) + + # TaskGroup to manage data updates and callbacks background tasks (with graceful shutdown) self._tasks = TasksPool() + + # Lock to prevent multiple concurrent writes to the same path + self._dst_lock = HierarchicalLock() + + # References to repeated polling tasks (periodic data fetch) self._polling_update_tasks = [] + + # Optional user-defined hooks for connection lifecycle self._on_connect_callbacks = on_connect or [] self._on_disconnect_callbacks = on_disconnect or [] @@ -143,45 +171,64 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc, tb): - """Context handler to terminate internal tasks.""" if not self._stopping: await self.stop() - async def _update_policy_data_callback(self, data: dict = None, topic=""): - """ - Pub/Sub callback - triggering data updates - will run when we get notifications on the policy_data topic. - i.e: when new roles are added, changes to permissions, etc. + async def _update_policy_data_callback(self, data: Optional[dict] = None, topic=""): + """Callback invoked by the Pub/Sub client whenever a data update is + published on one of our subscribed topics. 
+ + Calls trigger_data_update() with the DataUpdate object extracted + from 'data'. """ if data is not None: reason = data.get("reason", "") else: reason = "Periodic update" + logger.info("Updating policy data, reason: {reason}", reason=reason) update = DataUpdate.parse_obj(data) await self.trigger_data_update(update) async def trigger_data_update(self, update: DataUpdate): - # make sure the id has a unique id for tracking + """Queues up a data update to run in the background. If no update ID is + provided, generate one for tracking/logging. + + Note: + We spin off the data update in the background so that multiple updates + can run concurrently. Internally, the `_update_policy_data` method uses + a hierarchical lock to avoid race conditions when multiple updates try + to write to the same destination path. + """ + # Ensure we have a unique update ID if update.id is None: update.id = uuid.uuid4().hex + logger.info("Triggering data update with id: {id}", id=update.id) - # Fetching should be concurrent, but storing should be done in the original order - store_queue_number = await self._updates_storing_queue.take_a_number() - self._tasks.add_task(self._update_policy_data(update, store_queue_number)) + # Run the update in the background concurrently with other updates + # The TaskGroup will manage the lifecycle of this task, + # managing graceful shutdown of the updater without losing running data updates + self._tasks.add_task(self._update_policy_data(update)) async def get_policy_data_config(self, url: str = None) -> DataSourceConfig: - """ - Get the configuration for + """Fetches the DataSourceConfig (list of DataSourceEntry) from the + provided URL. + Args: - url: the URL to query for the config, Defaults to self._data_sources_config_url + url (str, optional): The URL to fetch data sources config from. Defaults to + self._data_sources_config_url if None is given. + + Raises: + ClientError: If the server responds with an error status. + Returns: - DataSourceConfig: the data sources config + DataSourceConfig: The parsed config containing data entries. """ if url is None: url = self._data_sources_config_url logger.info("Getting data-sources configuration from '{source}'", source=url) + try: async with ClientSession(headers=self._extra_headers) as session: response = await session.get(url, **self._ssl_context_kwargs) @@ -193,39 +240,47 @@ async def get_policy_data_config(self, url: str = None) -> DataSourceConfig: f"Fetch data sources failed with status code {response.status}, error: {error_details}" ) except: - logger.exception(f"Failed to load data sources config") + logger.exception("Failed to load data sources config") raise async def get_base_policy_data( self, config_url: str = None, data_fetch_reason="Initial load" ): - """Load data into the policy store according to the data source's - config provided in the config URL. + """Fetches an initial (or base) set of data from the configuration URL + and stores it in the policy store. + + This method also sets up any periodic data polling tasks for entries + that specify a 'periodic_update_interval'. Args: - config_url (str, optional): URL to retrieve data sources config from. Defaults to None ( self._data_sources_config_url). - data_fetch_reason (str, optional): Reason to log for the update operation. Defaults to "Initial load". + config_url (str, optional): A specific config URL to fetch from. If not given, + uses self._data_sources_config_url. + data_fetch_reason (str, optional): Reason for logging this fetch. Defaults to + "Initial load". 
""" logger.info( "Performing data configuration, reason: {reason}", reason=data_fetch_reason ) - await self._stop_polling_update_tasks() # If this is a reconnect - should stop previously received periodic updates + + # If we're reconnecting, stop any old periodic tasks before fetching anew + await self._stop_polling_update_tasks() + + # Fetch the base config with all data entries sources_config = await self.get_policy_data_config(url=config_url) init_entries, periodic_entries = [], [] for entry in sources_config.entries: - ( - periodic_entries - if (entry.periodic_update_interval is not None) - else init_entries - ).append(entry) + if entry.periodic_update_interval is not None: + periodic_entries.append(entry) + else: + init_entries.append(entry) - # Process one time entries now + # Process one-time entries now update = DataUpdate(reason=data_fetch_reason, entries=init_entries) await self.trigger_data_update(update) - # Schedule repeated processing of periodic polling entries - async def _trigger_update_with_entry(entry): + # Schedule repeated processing (polling) of periodic entries + async def _trigger_update_with_entry(entry: DataSourceEntry): await self.trigger_data_update( DataUpdate(reason="Periodic Update", entries=[entry]) ) @@ -239,20 +294,18 @@ async def _trigger_update_with_entry(entry): self._polling_update_tasks.append(asyncio.create_task(repeat_process_entry)) async def on_connect(self, client: PubSubClient, channel: RpcChannel): - """Pub/Sub on_connect callback On connection to backend, whether its - the first connection, or reconnecting after downtime, refetch the state - opa needs. + """Invoked when the Pub/Sub client establishes a connection to the + server. - As long as the connection is alive we know we are in sync with - the server, when the connection is lost we assume we need to - start from scratch. + By default, this re-fetches base policy data. Also publishes a + statistic event if statistics are enabled. """ logger.info("Connected to server") if self._fetch_on_connect: await self.get_base_policy_data() if opal_common_config.STATISTICS_ENABLED: + # Publish stats about the newly connected client await self._client.wait_until_ready() - # publish statistics to the server about new connection from client (only if STATISTICS_ENABLED is True, default to False) await self._client.publish( [opal_common_config.STATISTICS_ADD_CLIENT_CHANNEL], data={ @@ -263,21 +316,30 @@ async def on_connect(self, client: PubSubClient, channel: RpcChannel): ) async def on_disconnect(self, channel: RpcChannel): + """Invoked when the Pub/Sub client disconnects from the server.""" logger.info("Disconnected from server") async def start(self): + """ + Starts the DataUpdater: + - Begins listening for Pub/Sub data update events. + - Starts the callbacks reporter for asynchronous callback tasks. + - Starts the DataFetcher if not already running. + """ logger.info("Launching data updater") await self._callbacks_reporter.start() - await self._updates_storing_queue.start_queue_handling( - self._store_fetched_update - ) + if self._subscriber_task is None: + # The subscriber task runs in the background, receiving data update events self._subscriber_task = asyncio.create_task(self._subscriber()) await self._data_fetcher.start() async def _subscriber(self): - """Coroutine meant to be spunoff with create_task to listen in the - background for data events and pass them to the data_fetcher.""" + """The main loop for subscribing to Pub/Sub topics. 
+ + Waits for data update notifications and dispatches them to our + callback. + """ logger.info("Subscribing to topics: {topics}", topics=self._data_topics) self._client = PubSubClient( self._data_topics, @@ -294,17 +356,28 @@ async def _subscriber(self): await self._client.wait_until_done() async def _stop_polling_update_tasks(self): - if len(self._polling_update_tasks) > 0: + """Cancels all periodic polling tasks (if any). + + Used on reconnection or shutdown to ensure we don't have stale + tasks still running. + """ + if self._polling_update_tasks: for task in self._polling_update_tasks: task.cancel() await asyncio.gather(*self._polling_update_tasks, return_exceptions=True) self._polling_update_tasks = [] async def stop(self): + """ + Cleanly shuts down the DataUpdater: + - Disconnects the Pub/Sub client. + - Stops polling tasks. + - Cancels the subscriber background task. + - Stops the data fetcher and callback reporter. + """ self._stopping = True logger.info("Stopping data updater") - # disconnect from Pub/Sub if self._client is not None: try: await asyncio.wait_for(self._client.disconnect(), timeout=3) @@ -313,10 +386,9 @@ async def stop(self): "Timeout waiting for DataUpdater pubsub client to disconnect" ) - # stop periodic updates await self._stop_polling_update_tasks() - # stop subscriber task + # Cancel the subscriber task if self._subscriber_task is not None: logger.debug("Cancelling DataUpdater subscriber task") self._subscriber_task.cancel() @@ -330,216 +402,301 @@ async def stop(self): self._subscriber_task = None logger.debug("DataUpdater subscriber task was cancelled") - # stop the data fetcher + # Stop the DataFetcher logger.debug("Stopping data fetcher") await self._data_fetcher.stop() - # stop queue handling - await self._updates_storing_queue.stop_queue_handling() - - # stop the callbacks reporter + # Stop the callbacks reporter await self._callbacks_reporter.stop() + # Exit the TaskGroup context + await self._tasks.shutdown() + async def wait_until_done(self): + """Blocks until the Pub/Sub subscriber task completes. + + Typically, this runs indefinitely unless a stop/shutdown event + occurs. + """ if self._subscriber_task is not None: await self._subscriber_task @staticmethod - def calc_hash(data): - """Calculate an hash (sah256) on the given data, if data isn't a - string, it will be converted to JSON. + def calc_hash(data: JsonableValue) -> str: + """Calculates a SHA-256 hash of the given data to be used to identify + the updates (e.g. in logging reports on the transactions) . If 'data' + is not a string, it is first serialized to JSON. Returns an empty + string on failure. + + Args: + data (JsonableValue): The data to be hashed. - String are encoded as 'utf-8' prior to hash calculation. Returns: - the hash of the given data (as a a hexdigit string) or '' on failure to process. + str: The hexadecimal representation of the SHA-256 hash. """ try: if not isinstance(data, str): data = json.dumps(data, default=pydantic_encoder) return hashlib.sha256(data.encode("utf-8")).hexdigest() - except: - logger.exception("Failed to calculate hash for data {data}", data=data) + except Exception as e: + logger.exception(f"Failed to calculate hash for data {data}: {e}") return "" - async def _update_policy_data( - self, - update: DataUpdate, - store_queue_number: TakeANumberQueue.Number, - ): - """Fetches policy data (policy configuration) from backend and updates - it into policy-store (i.e. 
OPA)""" - - if update is None: - return - - # types / defaults - urls: List[Tuple[str, FetcherConfig, Optional[JsonableValue]]] = None - entries: List[DataSourceEntry] = [] - # if we have an actual specification for the update - if update is not None: - # Check each entry's topics to only process entries designated to us - entries = [ - entry - for entry in update.entries - if entry.topics - and not set(entry.topics).isdisjoint(set(self._data_topics)) - ] - urls = [] - for entry in entries: - config = entry.config - if self._shard_id is not None: - headers = config.get("headers", {}) - headers.update({"X-Shard-ID": self._shard_id}) - config["headers"] = headers - urls.append((entry.url, config, entry.data)) - - if len(entries) > 0: - logger.info("Fetching policy data", urls=repr(urls)) - else: - logger.warning( - "None of the update's entries are designated to subscribed topics" + async def _update_policy_data(self, update: DataUpdate) -> None: + """Performs the core data update process for the given DataUpdate + object. + + Steps: + 1. Iterate over the DataUpdate entries. + 2. For each entry, check if any of its topics match our client's topics. + 3. Acquire a lock for the destination path, so we don't fetch and overwrite concurrently. + - Note: This means that fetches that can technically happen concurrently wait on one another. + This can be improved with a Fetcher-Writer Lock ( a la Reader-Writer Lock ) pattern + 4. Fetch the data from the source (if applicable). + 5. Write the data into the policy store. + 6. Collect a report (success/failure, hash of the data, etc.). + 7. Send a consolidated report after processing all entries. + + Args: + update (DataUpdate): The data update instructions (entries, reason, etc.). + + Returns: + None + """ + reports: list[DataEntryReport] = [] + + for entry in update.entries: + if not entry.topics: + logger.debug("Data entry {entry} has no topics, skipping", entry=entry) + continue + + # Only process entries that match one of our subscribed data topics + if set(entry.topics).isdisjoint(set(self._data_topics)): + logger.debug( + "Data entry {entry} has no topics matching the data topics, skipping", + entry=entry, + ) + continue + + transaction_context = self._policy_store.transaction_context( + update.id, transaction_type=TransactionType.data ) - # Urls may be None - handle_urls has a default for None - policy_data_with_urls = await self._data_fetcher.handle_urls(urls) - store_queue_number.put((update, entries, policy_data_with_urls)) - - async def _store_fetched_update(self, update_item): - (update, entries, policy_data_with_urls) = update_item - - # track the result of each url in order to report back - reports: List[DataEntryReport] = [] - - # Save the data from the update - # We wrap our interaction with the policy store with a transaction - async with self._policy_store.transaction_context( - update.id, transaction_type=TransactionType.data - ) as store_transaction: - # for intellisense treat store_transaction as a PolicyStoreClient (which it proxies) - store_transaction: BasePolicyStoreClient - error_content = None - for (url, fetch_config, result), entry in itertools.zip_longest( - policy_data_with_urls, entries + # Acquire a per-destination lock to avoid overwriting the same path concurrently + async with ( + transaction_context as store_transaction, + self._dst_lock.lock(entry.dst_path), ): - fetched_data_successfully = True - - if isinstance(result, Exception): - fetched_data_successfully = False - logger.error( - "Failed to fetch url {url}, 
got exception: {exc}", - url=url, - exc=result, - ) + report = await self._fetch_and_save_data(entry, store_transaction) - if isinstance( - result, aiohttp.ClientResponse - ) and is_http_error_response( - result - ): # error responses - fetched_data_successfully = False - try: - error_content = await result.json() - logger.error( - "Failed to fetch url {url}, got response code {status} with error: {error}", - url=url, - status=result.status, - error=error_content, - ) - except json.JSONDecodeError: - error_content = await result.text() - logger.error( - "Failed to decode response from url:{url}, got response code {status} with response: {error}", - url=url, - status=result.status, - error=error_content, - ) - store_transaction._update_remote_status( - url=url, - status=fetched_data_successfully, - error=str(error_content), - ) + reports.append(report) - if fetched_data_successfully: - # get path to store the URL data (default mode (None) is as "" - i.e. as all the data at root) - policy_store_path = "" if entry is None else entry.dst_path - # None is not valid - use "" (protect from missconfig) - if policy_store_path is None: - policy_store_path = "" - # fix opa_path (if not empty must start with "/" to be nested under data) - if policy_store_path != "" and not policy_store_path.startswith( - "/" - ): - policy_store_path = f"/{policy_store_path}" - policy_data = result - # Create a report on the data-fetching - report = DataEntryReport( - entry=entry, hash=self.calc_hash(policy_data), fetched=True - ) + await self._send_reports(reports, update) - try: - if ( - opal_client_config.SPLIT_ROOT_DATA - and policy_store_path in ("/", "") - and isinstance(policy_data, dict) - ): - await self._set_split_policy_data( - store_transaction, - url=url, - save_method=entry.save_method, - data=policy_data, - ) - else: - await self._set_policy_data( - store_transaction, - url=url, - path=policy_store_path, - save_method=entry.save_method, - data=policy_data, - ) - # No exception we we're able to save to the policy-store - report.saved = True - # save the report for the entry - reports.append(report) - except Exception: - logger.exception("Failed to save data update to policy-store") - # we failed to save to policy-store - report.saved = False - # save the report for the entry - reports.append(report) - # re-raise so the context manager will be aware of the failure - raise - else: - report = DataEntryReport(entry=entry, fetched=False, saved=False) - # save the report for the entry - reports.append(report) - # should we send a report to defined callbackers? + async def _send_reports(self, reports: list[DataEntryReport], update: DataUpdate): + """Handles the reporting of completed data updates back to callbacks. + + Args: + reports (List[DataEntryReport]): List of individual entry reports. + update (DataUpdate): The overall DataUpdate object (contains reason, etc.). 
+ """ if self._should_send_reports: - # spin off reporting (no need to wait on it) + # Merge into a single DataUpdateReport whole_report = DataUpdateReport(update_id=update.id, reports=reports) extra_callbacks = self._callbacks_register.normalize_callbacks( update.callback.callbacks ) + # Asynchronously send the report to any configured callbacks self._tasks.add_task( self._callbacks_reporter.report_update_results( whole_report, extra_callbacks ) ) + async def _fetch_and_save_data( + self, + entry: DataSourceEntry, + store_transaction: PolicyStoreTransactionContextManager, + ) -> DataEntryReport: + """Orchestrates fetching data from a source and saving it into the + policy store. + + Flow: + 1. Attempt to fetch data via the data fetcher (e.g., HTTP). + 2. If data is fetched successfully, store it in the policy store. + 3. Return a DataEntryReport indicating success/failure of each step. + + Args: + entry (DataSourceEntry): The configuration details of the data source entry. + store_transaction (PolicyStoreTransactionContextManager): An active + transaction to the policy store. + + Returns: + DataEntryReport: Includes information about whether data was fetched, + saved, and the computed hash for the data if successfully saved. + """ + try: + result = await self._fetch_data(entry) + except Exception as e: + store_transaction._update_remote_status( + url=entry.url, status=False, error=str(e) + ) + return DataEntryReport(entry=entry, fetched=False, saved=False) + + try: + await self._store_fetched_data(entry, result, store_transaction) + except Exception as e: + logger.exception("Failed to save data update to policy-store: {exc}", exc=e) + store_transaction._update_remote_status( + url=entry.url, + status=False, + error=f"Failed to save data to policy store: {e}", + ) + return DataEntryReport( + entry=entry, hash=self.calc_hash(result), fetched=True, saved=False + ) + else: + store_transaction._update_remote_status( + url=entry.url, status=True, error="" + ) + return DataEntryReport( + entry=entry, hash=self.calc_hash(result), fetched=True, saved=True + ) + + async def _fetch_data(self, entry: DataSourceEntry) -> JsonableValue: + """Fetches data from a data source using the configured data fetcher. + Handles fetch errors, HTTP errors, and empty responses. + + Args: + entry (DataSourceEntry): The configuration specifying how and where to fetch data. + + Returns: + JsonableValue: The fetched data, as a JSON-serializable object. 
+ """ + try: + result = await self._data_fetcher.handle_url( + url=entry.url, + config=entry.config, + data=entry.data, + ) + except Exception as e: + logger.exception( + "Failed to fetch data for entry {entry} with exception {exc}", + entry=entry, + exc=e, + ) + raise Exception(f"Failed to fetch data for entry {entry.url}: {e}") + + if result is None: + raise Exception(f"Fetched data is empty for entry {entry.url}") + + if isinstance(result, aiohttp.ClientResponse) and is_http_error_response( + result + ): + error_content = await result.text() + logger.error( + "Failed to decode response from url: '{url}', got response code {status} with response: {error}", + url=entry.url, + status=result.status, + error=error_content, + ) + raise Exception( + f"Failed to decode response from url: '{entry.url}', got response code {result.status} with response: {error_content}" + ) + + return result + + async def _store_fetched_data( + self, + entry: DataSourceEntry, + result: JsonableValue, + store_transaction: PolicyStoreTransactionContextManager, + ) -> None: + """Decides how to store fetched data (entirely or split by root keys) + in the policy store based on the configuration. + + Args: + entry (DataSourceEntry): The configuration specifying how and where to store data. + result (JsonableValue): The fetched data to be stored. + store_transaction (PolicyStoreTransactionContextManager): The policy store + transaction under which to perform the write operations. + + Raises: + Exception: If storing data fails for any reason. + """ + policy_store_path = entry.dst_path or "" + if policy_store_path and not policy_store_path.startswith("/"): + policy_store_path = f"/{policy_store_path}" + + # If splitting root-level data is enabled and the path is "/", each top-level key + # is stored individually to avoid overwriting the entire data root. + if ( + opal_client_config.SPLIT_ROOT_DATA + and policy_store_path in ("/", "") + and isinstance(result, dict) + ): + await self._set_split_policy_data( + store_transaction, + url=entry.url, + save_method=entry.save_method, + data=result, + ) + else: + await self._set_policy_data( + store_transaction, + url=entry.url, + path=policy_store_path, + save_method=entry.save_method, + data=result, + ) + async def _set_split_policy_data( - self, tx, url: str, save_method: str, data: Dict[str, Any] + self, + tx: PolicyStoreTransactionContextManager, + url: str, + save_method: str, + data: Dict[str, Any], ): - """Split data writes to root ("/") path, so they won't overwrite other - sources.""" + """Splits data writes for root path ("/") so we don't overwrite + existing keys. + + For each top-level key in the dictionary, we create a sub-path under "/" + and save the corresponding value. + + Args: + tx (PolicyStoreTransactionContextManager): The active store transaction. + url (str): The data source URL (used for logging/reporting). + save_method (str): Either "PUT" (full overwrite) or "PATCH" (merge). + data (Dict[str, Any]): The dictionary to be split and stored. 
+ """ logger.info("Splitting root data to {n} keys", n=len(data)) for prefix, obj in data.items(): await self._set_policy_data( - tx, url=url, path=f"/{prefix}", save_method=save_method, data=obj + tx, + url=url, + path=f"/{prefix}", + save_method=save_method, + data=obj, ) async def _set_policy_data( - self, tx, url: str, path: str, save_method: str, data: JsonableValue + self, + tx: PolicyStoreTransactionContextManager, + url: str, + path: str, + save_method: str, + data: JsonableValue, ): + """Persists data to a specific path in the policy store. + + Args: + tx (PolicyStoreTransactionContextManager): The active store transaction. + url (str): The URL of the source data (used for logging/reporting). + path (str): The policy store path where data will be stored (e.g. "/roles"). + save_method (str): Either "PUT" (full overwrite) or "PATCH" (partial merge). + data (JsonableValue): The data to be written. + """ logger.info( "Saving fetched data to policy-store: source url='{url}', destination path='{path}'", url=url, @@ -552,4 +709,11 @@ async def _set_policy_data( @property def callbacks_reporter(self) -> CallbacksReporter: + """Provides external access to the CallbacksReporter instance, so that + users of DataUpdater can register custom callbacks or manipulate + reporting flows. + + Returns: + CallbacksReporter: The internal callbacks reporter. + """ return self._callbacks_reporter diff --git a/packages/opal-client/opal_client/tests/data_updater_test.py b/packages/opal-client/opal_client/tests/data_updater_test.py index f2b27b0fb..c7a82f4bd 100644 --- a/packages/opal-client/opal_client/tests/data_updater_test.py +++ b/packages/opal-client/opal_client/tests/data_updater_test.py @@ -196,17 +196,18 @@ async def test_data_updater(server): proc.terminate() # test PATCH update event via API - entries = [ - DataSourceEntry( - url="", - data=PATCH_DATA_UPDATE, - dst_path="/", - topics=DATA_TOPICS, - save_method="PATCH", - ) - ] update = DataUpdate( - reason="Test_Patch", entries=entries, callback=UpdateCallback(callbacks=[]) + reason="Test_Patch", + entries=[ + DataSourceEntry( + url="", + data=PATCH_DATA_UPDATE, + dst_path="/", + topics=DATA_TOPICS, + save_method="PATCH", + ) + ], + callback=UpdateCallback(callbacks=[]), ) headers = {"content-type": "application/json"} @@ -218,13 +219,26 @@ async def test_data_updater(server): ) assert res.status_code == 200 # value field is not specified for add operation should fail - entries[0].data = [{"op": "add", "path": "/"}] res = requests.post( DATA_UPDATE_ROUTE, - data=json.dumps(update, default=pydantic_encoder), + data=json.dumps( + { + "reason": "Test_Patch", + "entries": [ + { + "url": "", + "data": [{"op": "add", "path": "/"}], + "dst_path": "/", + "topics": DATA_TOPICS, + "save_method": "PATCH", + } + ], + }, + default=pydantic_encoder, + ), headers=headers, ) - assert res.status_code == 422 + assert res.status_code == 422, res.text @pytest.mark.asyncio diff --git a/packages/opal-common/opal_common/async_utils.py b/packages/opal-common/opal_common/async_utils.py index a2df90c69..fbed27cf9 100644 --- a/packages/opal-common/opal_common/async_utils.py +++ b/packages/opal-common/opal_common/async_utils.py @@ -3,9 +3,10 @@ import asyncio import sys from functools import partial -from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Coroutine, Optional, Set, Tuple, TypeVar import loguru +from loguru import logger if sys.version_info < (3, 10): from typing_extensions import ParamSpec @@ -94,16 
+95,40 @@ async def stop_queue_handling(self): class TasksPool: def __init__(self): - self._tasks: List[asyncio.Task] = [] + self._tasks: Set[asyncio.Task] = set() + self._running = True def _cleanup_task(self, done_task): self._tasks.remove(done_task) def add_task(self, f): + if not self._running: + raise RuntimeError("TasksPool is already shutdown") t = asyncio.create_task(f) - self._tasks.append(t) + self._tasks.add(t) t.add_done_callback(self._cleanup_task) + async def shutdown(self, force: bool = False): + """Wait for them to finish. + + :param force: If True, cancel all tasks immediately. + """ + self._running = False + if force: + for t in self._tasks: + t.cancel() + + results = await asyncio.gather( + *self._tasks, + return_exceptions=True, + ) + for result in results: + if isinstance(result, Exception): + logger.exception( + "Error on task during shutdown of TasksPool: {result}", + result=result, + ) + async def repeated_call( func: Coroutine, diff --git a/packages/opal-common/opal_common/fetcher/engine/fetching_engine.py b/packages/opal-common/opal_common/fetcher/engine/fetching_engine.py index b439d4b8d..1ab224707 100644 --- a/packages/opal-common/opal_common/fetcher/engine/fetching_engine.py +++ b/packages/opal-common/opal_common/fetcher/engine/fetching_engine.py @@ -124,7 +124,7 @@ async def queue_url( self, url: str, callback: Coroutine, - config: Union[FetcherConfig, dict] = None, + config: Union[FetcherConfig, dict, None] = None, fetcher="HttpFetchProvider", ) -> FetchEvent: """Simplified default fetching handler for queuing a fetch task. diff --git a/packages/opal-common/opal_common/synchronization/hierarchical_lock.py b/packages/opal-common/opal_common/synchronization/hierarchical_lock.py new file mode 100644 index 000000000..61aae3023 --- /dev/null +++ b/packages/opal-common/opal_common/synchronization/hierarchical_lock.py @@ -0,0 +1,85 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import Set + + +class HierarchicalLock: + """A hierarchical lock for asyncio. + + - If a path is locked, no ancestor or descendant path can be locked. + - Conversely, if a child path is locked, the parent path cannot be locked + until all child paths are released. + """ + + def __init__(self): + # locked_paths: set of currently locked string paths + self._locked_paths: Set[str] = set() + # Map of tasks to their acquired locks for re-entrant protection + self._task_locks: dict[asyncio.Task, Set[str]] = {} + # Internal lock for synchronizing access to locked_paths + self._lock = asyncio.Lock() + # Condition to wake up tasks when a path is released + self._cond = asyncio.Condition(self._lock) + + @staticmethod + def _is_conflicting(p1: str, p2: str) -> bool: + """Check if two paths conflict with each other.""" + return p1 == p2 or p1.startswith(p2) or p2.startswith(p1) + + async def acquire(self, path: str): + """Acquire the lock for the given hierarchical path. + + If an ancestor or descendant path is locked, this will wait + until it is released. 
+ """ + task = asyncio.current_task() + if task is None: + raise RuntimeError("acquire() must be called from within a task.") + + async with self._lock: + # Prevent re-entrant locking by the same task + if path in self._task_locks.get(task, set()): + raise RuntimeError(f"Task {task} cannot re-acquire lock on '{path}'.") + + # Wait until there is no conflict with existing locked paths + while any(self._is_conflicting(path, lp) for lp in self._locked_paths): + await self._cond.wait() + + # Acquire the path + self._locked_paths.add(path) + if task not in self._task_locks: + self._task_locks[task] = set() + self._task_locks[task].add(path) + + async def release(self, path: str): + """Release the lock for the given path and notify waiting tasks.""" + task = asyncio.current_task() + if task is None: + raise RuntimeError("release() must be called from within a task.") + + async with self._lock: + if path not in self._locked_paths: + raise RuntimeError(f"Cannot release path '{path}' that is not locked.") + + if path not in self._task_locks.get(task, set()): + raise RuntimeError( + f"Task {task} cannot release lock on '{path}' it does not hold." + ) + + # Remove the path from locked paths and task locks + self._locked_paths.remove(path) + self._task_locks[task].remove(path) + if not self._task_locks[task]: + del self._task_locks[task] + + # Notify all tasks that something was released + self._cond.notify_all() + + @asynccontextmanager + async def lock(self, path: str) -> "HierarchicalLock": + """Acquire the lock for the given path and return a context manager.""" + await self.acquire(path) + try: + yield self + finally: + await self.release(path) diff --git a/packages/opal-common/opal_common/tests/hierarchical_lock_test.py b/packages/opal-common/opal_common/tests/hierarchical_lock_test.py new file mode 100644 index 000000000..7579beb7f --- /dev/null +++ b/packages/opal-common/opal_common/tests/hierarchical_lock_test.py @@ -0,0 +1,209 @@ +import asyncio +from typing import Coroutine + +import pytest +from opal_common.synchronization.hierarchical_lock import HierarchicalLock + + +async def measure_duration(coro: Coroutine) -> float: + loop = asyncio.get_event_loop() + start_time = loop.time() + await coro + return loop.time() - start_time + + +@pytest.mark.asyncio +async def test_non_conflicting_paths(): + lock = HierarchicalLock() + + # Acquire a path for alice and a path for bob + # They should not block each other + async def lock_path(path): + async with lock.lock(path): + await asyncio.sleep(0.1) + + t1 = lock_path("alice") + t2 = lock_path("bob") + + # If both tasks complete quickly, the test passes. 
+ duration = await measure_duration( + asyncio.wait_for( + asyncio.gather(t1, t2), + timeout=10, + ) + ) + assert duration < 0.2, "Both paths should acquire lock concurrently" + + +@pytest.mark.asyncio +async def test_siblings_do_not_block(): + lock = HierarchicalLock() + + # Acquire two sibling paths concurrently + # They should not block each other + async def lock_sibling(path): + async with lock.lock(path): + await asyncio.sleep(0.1) + + t1 = lock_sibling("alice.age") + t2 = lock_sibling("alice.name") + + duration = await measure_duration( + asyncio.wait_for( + asyncio.gather(t1, t2), + timeout=10, + ) + ) + assert duration < 0.2, "Both siblings should acquire lock concurrently" + + +@pytest.mark.asyncio +async def test_parent_blocks_child(): + lock = HierarchicalLock() + + got_lock_child = asyncio.Event() + + async def lock_parent(): + await lock.acquire("alice") + # hold lock for some time so child attempts to acquire and is blocked + await asyncio.sleep(0.2) + await lock.release("alice") + + async def lock_child(): + await asyncio.sleep(0.1) # wait a moment so parent acquires first + await lock.acquire("alice.age") + got_lock_child.set() + await lock.release("alice.age") + + parent_task = lock_parent() + child_task = lock_child() + + # child should not be able to acquire immediately + # so we expect got_lock_child to not be set before 0.2s + await asyncio.sleep(0.15) + assert not got_lock_child.is_set(), "Child should be blocked by parent" + + # let everything finish + await asyncio.gather(parent_task, child_task) + assert got_lock_child.is_set(), "Child eventually acquires lock" + + +@pytest.mark.asyncio +async def test_children_block_parent(): + lock = HierarchicalLock() + + got_lock_parent = asyncio.Event() + + async def lock_child(path, delay=0): + await asyncio.sleep(delay) + async with lock.lock(path): + # hold it for some time + await asyncio.sleep(0.2) + + async def lock_parent(): + await asyncio.sleep(0.05) # ensure children get the lock first + async with lock.lock("alice"): + got_lock_parent.set() + + c1 = lock_child("alice.age", 0) + c2 = lock_child("alice.name", 0) + p = lock_parent() + + # Wait some time so the parent tries to acquire + # The parent should be blocked while children hold locks + await asyncio.sleep(0.1) + # Children have likely acquired their locks by now + assert not got_lock_parent.is_set(), "Parent should be blocked by child locks" + + await asyncio.gather(c1, c2, p) + assert got_lock_parent.is_set(), "Parent eventually acquires after children release" + + +# test same key block each other + + +@pytest.mark.asyncio +async def test_same_key_blocks(): + lock = HierarchicalLock() + + async def lock_task(): + async with lock.lock("alice"): + await asyncio.sleep(0.2) + + t1 = lock_task() + t2 = lock_task() + + duration = await measure_duration( + asyncio.wait_for( + asyncio.gather(t1, t2), + timeout=10, + ) + ) + assert duration >= 0.4, "Both tasks should not acquire lock concurrently" + + +@pytest.mark.asyncio +async def test_parent_waits_for_new_child(): + lock = HierarchicalLock() + + # We'll do a scenario: + # 1) Acquire alice.age, alice.name concurrently + # 2) Acquire alice -> must wait for both to release + # 3) While alice is waiting, acquire alice.height + # 4) Release all children, ensure alice eventually gets lock + + # We track the times we acquire the parent to ensure it's after all children. 
+ parent_acquired = False + + async def child_locker(path: str, hold: float = 0.1, delay: float = 0.0): + await asyncio.sleep(delay) + async with lock.lock(path): + await asyncio.sleep(hold) + + async def parent_locker(): + nonlocal parent_acquired + # start after children are running + await asyncio.sleep(0.05) + async with lock.lock("alice"): + parent_acquired = True + + c1 = child_locker("alice.age", hold=0.2, delay=0.0) + c2 = child_locker("alice.name", hold=0.2, delay=0.0) + p = parent_locker() + + # after a short moment, start new child 'alice.height' + c3 = child_locker("alice.height", hold=0.2, delay=0.1) + + await asyncio.gather(c1, c2, c3, p) + + assert parent_acquired, "Parent should eventually acquire after all children" + + +@pytest.mark.asyncio +async def test_release_on_non_locked_path(): + lock = HierarchicalLock() + + with pytest.raises(RuntimeError): + await lock.release("non-locked-path") + + +@pytest.mark.asyncio +async def test_same_task_reacquire_same_key_deadlock(): + # Test whether the same coroutine can re-acquire a path it already holds. + # By default, a non-reentrant lock should deadlock or raise an error. + + lock = HierarchicalLock() + + async def same_task(): + # Acquire once + await lock.acquire("alice") + # This should either block forever or raise an error if not re-entrant + with pytest.raises(RuntimeError): + await lock.acquire("alice") + # Release after the test above + await lock.release("alice") + + await asyncio.wait_for( + same_task(), + timeout=10, + )
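

A minimal usage sketch of the new HierarchicalLock, assuming only the lock()/acquire()/release() API added in hierarchical_lock.py above; the write coroutine and the example paths are illustrative, not part of OPAL. Conflicts are prefix-based (startswith), so a path blocks its ancestors and descendants but not siblings or unrelated paths.

    import asyncio

    from opal_common.synchronization.hierarchical_lock import HierarchicalLock


    async def main():
        lock = HierarchicalLock()

        async def write(path: str, label: str):
            # Conflicting paths (one is a prefix of the other) are serialized;
            # unrelated paths acquire the lock concurrently.
            async with lock.lock(path):
                print(f"{label}: writing under '{path}'")
                await asyncio.sleep(0.1)

        await asyncio.gather(
            write("/users", "parent"),
            write("/users/roles", "child"),  # waits until "/users" is released
            write("/billing", "unrelated"),  # no conflict, runs concurrently
        )


    asyncio.run(main())

This mirrors how _update_policy_data guards each entry.dst_path with self._dst_lock.lock(...) before writing to the policy store.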
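
A similar sketch for the reworked TasksPool, assuming only the add_task() and shutdown() behavior defined in async_utils.py above; the job coroutine is illustrative.

    import asyncio

    from opal_common.async_utils import TasksPool


    async def job(n: int):
        await asyncio.sleep(0.1)
        return n


    async def main():
        pool = TasksPool()
        for n in range(3):
            pool.add_task(job(n))
        # Graceful shutdown: waits for all in-flight tasks and logs any
        # exceptions; pass force=True to cancel them instead of waiting.
        await pool.shutdown()
        # After shutdown, add_task() raises RuntimeError ("TasksPool is already shutdown").


    asyncio.run(main())

This is the mechanism DataUpdater.stop() relies on so that in-flight data updates and callback reports finish (or are force-cancelled) before the updater exits.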