diff --git a/.pylintrc b/.pylintrc index d5c72dd..f111466 100644 --- a/.pylintrc +++ b/.pylintrc @@ -11,6 +11,7 @@ disable=duplicate-code, no-self-use, too-few-public-methods, too-many-arguments, + too-many-branches, too-many-locals, too-many-return-statements, too-many-instance-attributes, diff --git a/VERSION b/VERSION index b056f41..3df6a7a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1,2 @@ -0.0.24 +0.0.25 + diff --git a/exabel_data_sdk/client/api/bulk_insert.py b/exabel_data_sdk/client/api/bulk_insert.py new file mode 100644 index 0000000..959033d --- /dev/null +++ b/exabel_data_sdk/client/api/bulk_insert.py @@ -0,0 +1,97 @@ +from concurrent.futures.thread import ThreadPoolExecutor +from time import time +from typing import Callable, Sequence + +from exabel_data_sdk.client.api.data_classes.request_error import ErrorType, RequestError +from exabel_data_sdk.client.api.resource_creation_result import ( + ResourceCreationResult, + ResourceCreationResults, + ResourceCreationStatus, + TResource, +) + + +def _process( + results: ResourceCreationResults[TResource], + resource: TResource, + insert_func: Callable[[TResource], ResourceCreationStatus], +) -> None: + """ + Insert the given resource using the provided function. + Catches and handles RequestErrors. + + Args: + results: the result set to append to + resource: the resource to be inserted + insert_func: the function to use to insert the resource + """ + try: + status = insert_func(resource) + results.add(ResourceCreationResult(status, resource)) + except RequestError as error: + status = ( + ResourceCreationStatus.EXISTS + if error.error_type == ErrorType.ALREADY_EXISTS + else ResourceCreationStatus.FAILED + ) + results.add(ResourceCreationResult(status, resource, error)) + + +def _bulk_insert( + results: ResourceCreationResults[TResource], + resources: Sequence[TResource], + insert_func: Callable[[TResource], ResourceCreationStatus], + threads: int = 40, +) -> None: + """ + Calls the provided insert function with each of the provided resources, + while catching errors and tracking progress. + + Args: + results: add the results to this result set + resources: the resources to be inserted + insert_func: the function to call for each insert. + threads: the number of parallel upload threads to use + """ + if threads == 1: + for resource in resources: + _process(results, resource, insert_func) + else: + with ThreadPoolExecutor(max_workers=threads) as executor: + for resource in resources: + executor.submit(_process, results, resource, insert_func) + + +def bulk_insert( + resources: Sequence[TResource], + insert_func: Callable[[TResource], ResourceCreationStatus], + retries: int = 2, + threads: int = 40, +) -> ResourceCreationResults[TResource]: + """ + Calls the provided insert function with each of the provided resources, + while catching errors and tracking progress. + + Args: + resources: the resources to be inserted + insert_func: the function to call for each insert. 
+        retries: the maximum number of retries to make for each failed request
+        threads: the number of parallel upload threads to use
+
+    Returns:
+        the result set showing the current status for each insert
+    """
+    start_time = time()
+    results: ResourceCreationResults[TResource] = ResourceCreationResults(len(resources))
+    for trial in range(retries + 1):
+        if trial > 0:
+            failures = results.extract_retryable_failures()
+            if not failures:
+                break
+            resources = [result.resource for result in failures]
+            print(f"Retry #{trial} with {len(resources)} resources:")
+        _bulk_insert(results, resources, insert_func, threads=threads)
+    spent_time = int(time() - start_time)
+    print(f"Spent {spent_time} seconds loading {results.total_count} resources ({threads} threads)")
+    results.print_summary()
+    return results
diff --git a/exabel_data_sdk/client/api/data_classes/entity.py b/exabel_data_sdk/client/api/data_classes/entity.py
index 9f6fbe6..4180314 100644
--- a/exabel_data_sdk/client/api/data_classes/entity.py
+++ b/exabel_data_sdk/client/api/data_classes/entity.py
@@ -1,3 +1,4 @@
+import re
 from typing import Mapping, Union
 
 from exabel_data_sdk.client.api.proto_utils import from_struct, to_struct
@@ -89,3 +90,16 @@ def __repr__(self) -> str:
             f"description='{self.description}', properties={self.properties}, "
             f"read_only={self.read_only})"
         )
+
+    def __lt__(self, other: object) -> bool:
+        if not isinstance(other, Entity):
+            raise ValueError(f"Cannot compare Entity to non-Entity: {other}")
+        return self.name < other.name
+
+    def get_entity_type(self) -> str:
+        """Extracts the entity type name from the entity's resource name."""
+        p = re.compile(r"(entityTypes/[a-zA-Z0-9_\-.]+)/entities/[a-zA-Z0-9_\-.]+")
+        m = p.match(self.name)
+        if m:
+            return m.group(1)
+        raise ValueError(f"Could not parse entity resource name: {self.name}")
diff --git a/exabel_data_sdk/client/api/data_classes/request_error.py b/exabel_data_sdk/client/api/data_classes/request_error.py
index 8469049..61d37e0 100644
--- a/exabel_data_sdk/client/api/data_classes/request_error.py
+++ b/exabel_data_sdk/client/api/data_classes/request_error.py
@@ -1,6 +1,7 @@
-from enum import Enum
+from enum import Enum, unique
 
 
+@unique
 class ErrorType(Enum):
     """
     Error types.
     """
@@ -24,6 +25,10 @@ class ErrorType(Enum):
     # Any internal error.
INTERNAL = 10 + def retryable(self) -> bool: + """Return whether it makes sense to retry the request if this error is given.""" + return self in (ErrorType.UNAVAILABLE, ErrorType.TIMEOUT, ErrorType.INTERNAL) + class RequestError(Exception): """ diff --git a/exabel_data_sdk/client/api/entity_api.py b/exabel_data_sdk/client/api/entity_api.py index 695e45b..e772c07 100644 --- a/exabel_data_sdk/client/api/entity_api.py +++ b/exabel_data_sdk/client/api/entity_api.py @@ -1,15 +1,15 @@ -from typing import Callable, Optional, Sequence +from typing import Optional, Sequence from google.protobuf.field_mask_pb2 import FieldMask from exabel_data_sdk.client.api.api_client.grpc.entity_grpc_client import EntityGrpcClient from exabel_data_sdk.client.api.api_client.http.entity_http_client import EntityHttpClient +from exabel_data_sdk.client.api.bulk_insert import bulk_insert from exabel_data_sdk.client.api.data_classes.entity import Entity from exabel_data_sdk.client.api.data_classes.entity_type import EntityType from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult from exabel_data_sdk.client.api.data_classes.request_error import ErrorType, RequestError from exabel_data_sdk.client.api.resource_creation_result import ( - ResourceCreationResult, ResourceCreationResults, ResourceCreationStatus, ) @@ -188,29 +188,19 @@ def bulk_create_entities( self, entities: Sequence[Entity], entity_type: str, - status_callback: Callable[[ResourceCreationResults, int], None] = None, + threads: int = 40, ) -> ResourceCreationResults[Entity]: """ Check if the provided entities exist, and create them if they don't. All entities must be of the given entity_type. If an entity with the given name already exists, it is not updated. + """ + + def insert(entity: Entity) -> ResourceCreationStatus: + # Optimistically insert the entity. + # If the entity already exists, we'll get an ALREADY_EXISTS error from the backend, + # which is handled appropriately by the bulk_insert function. + self.create_entity(entity=entity, entity_type=entity_type) + return ResourceCreationStatus.CREATED - Optionally, a callback can be provided to track the progress. - The callback is called after every 10th entity is processed. 
- """ - results: ResourceCreationResults[Entity] = ResourceCreationResults() - for entity in entities: - try: - existing_entity = self.get_entity(entity.name) - if existing_entity is None: - new_entity = self.create_entity(entity=entity, entity_type=entity_type) - results.add(ResourceCreationResult(ResourceCreationStatus.CREATED, new_entity)) - else: - results.add( - ResourceCreationResult(ResourceCreationStatus.EXISTS, existing_entity) - ) - except RequestError as error: - results.add(ResourceCreationResult(ResourceCreationStatus.FAILED, entity, error)) - if status_callback and (results.count() % 10 == 0 or results.count() == len(entities)): - status_callback(results, len(entities)) - return results + return bulk_insert(entities, insert, threads=threads) diff --git a/exabel_data_sdk/client/api/relationship_api.py b/exabel_data_sdk/client/api/relationship_api.py index 170f7bd..7f980af 100644 --- a/exabel_data_sdk/client/api/relationship_api.py +++ b/exabel_data_sdk/client/api/relationship_api.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence +from typing import Optional, Sequence from google.protobuf.field_mask_pb2 import FieldMask @@ -8,12 +8,12 @@ from exabel_data_sdk.client.api.api_client.http.relationship_http_client import ( RelationshipHttpClient, ) +from exabel_data_sdk.client.api.bulk_insert import bulk_insert from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult from exabel_data_sdk.client.api.data_classes.relationship import Relationship from exabel_data_sdk.client.api.data_classes.relationship_type import RelationshipType from exabel_data_sdk.client.api.data_classes.request_error import ErrorType, RequestError from exabel_data_sdk.client.api.resource_creation_result import ( - ResourceCreationResult, ResourceCreationResults, ResourceCreationStatus, ) @@ -276,37 +276,18 @@ def relationship_exists(self, relationship_type: str, from_entity: str, to_entit def bulk_create_relationships( self, relationships: Sequence[Relationship], - status_callback: Callable[[ResourceCreationResults, int], None] = None, + threads: int = 40, ) -> ResourceCreationResults[Relationship]: """ Check if the provided relationships exist, and create them if they don't. If the relationship already exists, it is not updated. - - Optionally, a callback can be provided to track the progress. - The callback is called after every 10th relationship is processed. """ - total = len(relationships) - results: ResourceCreationResults[Relationship] = ResourceCreationResults() - for relationship in relationships: - try: - existing_relationship = self.get_relationship( - relationship_type=relationship.relationship_type, - from_entity=relationship.from_entity, - to_entity=relationship.to_entity, - ) - if existing_relationship is None: - new_relationship = self.create_relationship(relationship=relationship) - results.add( - ResourceCreationResult(ResourceCreationStatus.CREATED, new_relationship) - ) - else: - results.add( - ResourceCreationResult(ResourceCreationStatus.EXISTS, existing_relationship) - ) - except RequestError as error: - results.add( - ResourceCreationResult(ResourceCreationStatus.FAILED, relationship, error) - ) - if status_callback and (results.count() % 10 == 0 or results.count() == total): - status_callback(results, total) - return results + + def insert(relationship: Relationship) -> ResourceCreationStatus: + # Optimistically insert the relationship. 
+ # If the relationship already exists, we'll get an ALREADY_EXISTS error from the + # backend, which is handled appropriately by the bulk_insert function. + self.create_relationship(relationship=relationship) + return ResourceCreationStatus.CREATED + + return bulk_insert(relationships, insert, threads=threads) diff --git a/exabel_data_sdk/client/api/resource_creation_result.py b/exabel_data_sdk/client/api/resource_creation_result.py index 875b88a..545efdf 100644 --- a/exabel_data_sdk/client/api/resource_creation_result.py +++ b/exabel_data_sdk/client/api/resource_creation_result.py @@ -48,14 +48,29 @@ class ResourceCreationResults(Generic[TResource]): Class for returning resource creation results. """ - def __init__(self) -> None: + def __init__( + self, total_count: int, print_status: bool = True, abort_threshold: float = 0.5 + ) -> None: + """ + Args: + total_count: The total number of resources expected to be loaded. + print_status: Whether to print status of the upload during processing. + abort_threshold: If the fraction of failed requests exceeds this threshold, + the upload is aborted, and the script exits. + Note that this only happens if print_status is set to True. + """ self.results: List[ResourceCreationResult[TResource]] = [] self.counter: Counter = Counter() + self.total_count = total_count + self.do_print_status = print_status + self.abort_threshold = abort_threshold def add(self, result: ResourceCreationResult[TResource]) -> None: """Add the result for a resource.""" self.results.append(result) self.counter.update([result.status]) + if self.do_print_status and (self.count() % 20 == 0 or self.count() == self.total_count): + self.print_status() def count(self, status: ResourceCreationStatus = None) -> int: """ @@ -64,37 +79,62 @@ def count(self, status: ResourceCreationStatus = None) -> int: """ return len(self.results) if status is None else self.counter[status] + def extract_retryable_failures(self) -> List[ResourceCreationResult[TResource]]: + """ + Remove all retryable failures from this result set, + and return them. + """ + failed = [] + rest = [] + for result in self.results: + if ( + result.status == ResourceCreationStatus.FAILED + and result.error + and result.error.error_type.retryable() + ): + failed.append(result) + else: + rest.append(result) + self.counter.subtract([result.status for result in failed]) + self.results = rest + return failed + def print_summary(self) -> None: """Prints a human legible summary of the resource creation results to screen.""" print(self.counter[ResourceCreationStatus.CREATED], "new resources created") - print(self.counter[ResourceCreationStatus.EXISTS], "resources already existed") + if self.counter[ResourceCreationStatus.EXISTS]: + print(self.counter[ResourceCreationStatus.EXISTS], "resources already existed") if self.counter[ResourceCreationStatus.FAILED]: print(self.counter[ResourceCreationStatus.FAILED], "resources failed:") for result in self.results: if result.status == ResourceCreationStatus.FAILED: print(" ", result.resource, ":\n ", result.error) + def print_status(self) -> None: + """ + Prints a status update on the progress of the data loading, showing the percentage complete + and how many objects were created, already existed or failed. -def status_callback(results: ResourceCreationResults, total_count: int) -> None: - """ - Prints a status update on the progress of the data loading, showing the percentage complete - and how many objects were created, already existed or failed. 
- - Note that the previous status message is overwritten (by writing '\r'), - but this only works if nothing else has been printed to stdout since the last update. - """ - fraction_complete = results.count() / total_count - sys.stdout.write( - f"\r{fraction_complete:.0%} - " - f"{results.count(ResourceCreationStatus.CREATED)} created, " - f"{results.count(ResourceCreationStatus.EXISTS)} exists, " - f"{results.count(ResourceCreationStatus.FAILED)} failed" - ) - if fraction_complete == 1: - sys.stdout.write("\n") - fraction_error = results.count(ResourceCreationStatus.FAILED) / results.count() - if fraction_error > 0.5: - sys.stdout.write("\nAborting - more than half the requests are failing.\n") - results.print_summary() - sys.exit(-1) - sys.stdout.flush() + Note that the previous status message is overwritten (by writing '\r'), + but this only works if nothing else has been printed to stdout since the last update. + """ + fraction_complete = self.count() / self.total_count + sys.stdout.write( + f"\r{fraction_complete:.0%} - " + f"{self.count(ResourceCreationStatus.CREATED)} created, " + f"{self.count(ResourceCreationStatus.EXISTS)} exists, " + f"{self.count(ResourceCreationStatus.FAILED)} failed" + ) + if fraction_complete == 1: + sys.stdout.write("\n") + else: + fraction_error = self.count(ResourceCreationStatus.FAILED) / self.count() + if fraction_error > self.abort_threshold: + sys.stdout.write( + f"\nAborting - more than {self.abort_threshold:.0%} " + "of the requests are failing.\n" + ) + self.print_summary() + sys.stdout.flush() + sys.exit(1) + sys.stdout.flush() diff --git a/exabel_data_sdk/client/api/time_series_api.py b/exabel_data_sdk/client/api/time_series_api.py index 9b8353d..65de333 100644 --- a/exabel_data_sdk/client/api/time_series_api.py +++ b/exabel_data_sdk/client/api/time_series_api.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Sequence +from typing import Optional, Sequence import pandas as pd from dateutil import tz @@ -7,10 +7,10 @@ from exabel_data_sdk.client.api.api_client.grpc.time_series_grpc_client import TimeSeriesGrpcClient from exabel_data_sdk.client.api.api_client.http.time_series_http_client import TimeSeriesHttpClient +from exabel_data_sdk.client.api.bulk_insert import bulk_insert from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult from exabel_data_sdk.client.api.data_classes.request_error import ErrorType, RequestError from exabel_data_sdk.client.api.resource_creation_result import ( - ResourceCreationResult, ResourceCreationResults, ResourceCreationStatus, ) @@ -164,11 +164,16 @@ def upsert_time_series(self, name: str, series: pd.Series, create_tag: bool = Fa Returns: True if the time series already existed, or False if it is created """ - if self.time_series_exists(name): + try: + # Optimistically assume that the time series exists, and append to it. + # If it doesn't exist, we catch the error below and create the time series instead. 
self.append_time_series_data(name, series) return True - self.create_time_series(name, series, create_tag) - return False + except RequestError as error: + if error.error_type == ErrorType.NOT_FOUND: + self.create_time_series(name, series, create_tag) + return False + raise def clear_time_series_data(self, name: str, start: pd.Timestamp, end: pd.Timestamp) -> None: """ @@ -184,13 +189,32 @@ def clear_time_series_data(self, name: str, start: pd.Timestamp, end: pd.Timesta BatchDeleteTimeSeriesPointsRequest(name=name, time_ranges=[time_range]), ) - def append_time_series_data(self, name: str, series: pd.Series) -> pd.Series: + def append_time_series_data(self, name: str, series: pd.Series) -> None: """ Append data to the given time series. If the given series contains data points that already exist, these data points will be overwritten. + Args: + name: The resource name of the time series. + series: Series with data to append. + """ + self.client.update_time_series( + UpdateTimeSeriesRequest( + time_series=ProtoTimeSeries( + name=name, points=self._series_to_time_series_points(series) + ), + ), + ) + + def append_time_series_data_and_return(self, name: str, series: pd.Series) -> pd.Series: + """ + Append data to the given time series, and return the full series. + + If the given series contains data points that already exist, these data points will be + overwritten. + Args: name: The resource name of the time series. series: Series with data to append. @@ -198,14 +222,12 @@ def append_time_series_data(self, name: str, series: pd.Series) -> pd.Series: Returns: A series with all data for the given time series. """ - proto_time_series = ProtoTimeSeries( - name=name, points=self._series_to_time_series_points(series) - ) - # Set empty TimeRange() in request to get back entire time series. time_series = self.client.update_time_series( UpdateTimeSeriesRequest( - time_series=proto_time_series, + time_series=ProtoTimeSeries( + name=name, points=self._series_to_time_series_points(series) + ), view=TimeSeriesView(time_range=TimeRange()), ), ) @@ -236,7 +258,7 @@ def bulk_upsert_time_series( self, series: Sequence[pd.Series], create_tag: bool = False, - status_callback: Callable[[ResourceCreationResults, int], None] = None, + threads: int = 40, ) -> ResourceCreationResults[pd.Series]: """ Calls upsert_time_series for each of the provided time series, @@ -246,25 +268,18 @@ def bulk_upsert_time_series( See the docstring of upsert_time_series regarding required format for this resource name. Args: - series: the time series to be inserted - create_tag: Set to true to create a tag for every entity type a signal has time series - for. If a tag already exists, it will be updated when time series are - created (or deleted) regardless of the value of this flag. - status_callback: Called after every 10th time series is processed, to track progress. + series: The time series to be inserted + create_tag: Set to true to create a tag for every entity type a signal has time + series for. If a tag already exists, it will be updated when time + series are created (or deleted) regardless of the value of this flag. + threads: The number of parallel upload threads to use. 
""" - results: ResourceCreationResults[pd.Series] = ResourceCreationResults() - for ts in series: - try: - existed = self.upsert_time_series(str(ts.name), ts, create_tag=create_tag) - status = ( - ResourceCreationStatus.EXISTS if existed else ResourceCreationStatus.CREATED - ) - results.add(ResourceCreationResult(status, ts)) - except RequestError as error: - results.add(ResourceCreationResult(ResourceCreationStatus.FAILED, ts, error)) - if status_callback and (results.count() % 10 == 0 or results.count() == len(series)): - status_callback(results, len(series)) - return results + + def insert(ts: pd.Series) -> ResourceCreationStatus: + existed = self.upsert_time_series(str(ts.name), ts, create_tag=create_tag) + return ResourceCreationStatus.EXISTS if existed else ResourceCreationStatus.CREATED + + return bulk_insert(series, insert, threads=threads) @staticmethod def _series_to_time_series_points(series: pd.Series) -> Sequence[TimeSeriesPoint]: diff --git a/exabel_data_sdk/scripts/csv_script.py b/exabel_data_sdk/scripts/csv_script.py index 31287e2..9833692 100644 --- a/exabel_data_sdk/scripts/csv_script.py +++ b/exabel_data_sdk/scripts/csv_script.py @@ -1,6 +1,6 @@ import argparse import os -from typing import Sequence +from typing import Collection, Mapping, Optional, Sequence, Union import pandas as pd @@ -46,7 +46,21 @@ def __init__(self, argv: Sequence[str], description: str): default=namespace, help=help_text, ) + self.parser.add_argument( + "--threads", + required=False, + type=int, + choices=range(1, 101), + metavar="[1-100]", + default=40, + help="The number of parallel upload threads to run. Defaults to 40.", + ) - def read_csv(self, args: argparse.Namespace) -> pd.DataFrame: + def read_csv( + self, args: argparse.Namespace, string_columns: Collection[Union[str, int]] = None + ) -> pd.DataFrame: """Read the CSV file from disk with the filename specified by command line argument.""" - return pd.read_csv(args.filename, header=0, sep=args.sep) + dtype: Optional[Mapping[Union[str, int], type]] = None + if string_columns: + dtype = {column: str for column in string_columns} + return pd.read_csv(args.filename, header=0, sep=args.sep, dtype=dtype) diff --git a/exabel_data_sdk/scripts/load_entities_from_csv.py b/exabel_data_sdk/scripts/load_entities_from_csv.py index f831673..3b3d3f2 100644 --- a/exabel_data_sdk/scripts/load_entities_from_csv.py +++ b/exabel_data_sdk/scripts/load_entities_from_csv.py @@ -4,7 +4,6 @@ from exabel_data_sdk import ExabelClient from exabel_data_sdk.client.api.data_classes.entity import Entity -from exabel_data_sdk.client.api.resource_creation_result import status_callback from exabel_data_sdk.scripts.csv_script import CsvScript from exabel_data_sdk.util.resource_name_normalization import normalize_resource_name @@ -64,7 +63,14 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: print("Running dry-run...") print("Loading entities from", args.filename) - entities_df = self.read_csv(args) + name_col_ref = args.name_column or 0 + string_columns = { + name_col_ref, + args.display_name_column or name_col_ref, + } + if args.description_column: + string_columns.add(args.description_column) + entities_df = self.read_csv(args, string_columns=string_columns) name_col = args.name_column or entities_df.columns[0] display_name_col = args.display_name_column or name_col @@ -93,10 +99,7 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: print(entities) return - results = client.entity_api.bulk_create_entities( - 
entities, entity_type_name, status_callback - ) - results.print_summary() + client.entity_api.bulk_create_entities(entities, entity_type_name, threads=args.threads) if __name__ == "__main__": diff --git a/exabel_data_sdk/scripts/load_relationships_from_csv.py b/exabel_data_sdk/scripts/load_relationships_from_csv.py index 25adea6..04a88a2 100644 --- a/exabel_data_sdk/scripts/load_relationships_from_csv.py +++ b/exabel_data_sdk/scripts/load_relationships_from_csv.py @@ -5,7 +5,6 @@ from exabel_data_sdk import ExabelClient from exabel_data_sdk.client.api.data_classes.relationship import Relationship from exabel_data_sdk.client.api.data_classes.relationship_type import RelationshipType -from exabel_data_sdk.client.api.resource_creation_result import status_callback from exabel_data_sdk.scripts.csv_script import CsvScript from exabel_data_sdk.util.resource_name_normalization import to_entity_resource_names @@ -76,13 +75,20 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: f"to {args.entity_to_column} from {args.filename}" ) - relationships_df = self.read_csv(args) + string_columns = { + args.entity_from_column, + args.entity_to_column, + } + if args.description_column: + string_columns.add(args.description_column) + + relationships_df = self.read_csv(args, string_columns=string_columns) entity_from_col = args.entity_from_column entity_to_col = args.entity_to_column description_col = args.description_column - relationship_type_name = f"relationshipTypes/{args.relationship_type}" + relationship_type_name = f"relationshipTypes/{args.namespace}.{args.relationship_type}" relationship_type = client.relationship_api.get_relationship_type(relationship_type_name) if not relationship_type: @@ -90,7 +96,8 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: relationship_type = RelationshipType(name=relationship_type_name) client.relationship_api.create_relationship_type(relationship_type) print("Available relationship types are:") - print(client.relationship_api.list_relationship_types()) + for rel_type in client.relationship_api.list_relationship_types().results: + print(" ", rel_type) relationships_df[entity_from_col] = to_entity_resource_names( client.entity_api, relationships_df[entity_from_col], namespace=args.namespace @@ -117,8 +124,7 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: print(relationships) return - results = client.relationship_api.bulk_create_relationships(relationships, status_callback) - results.print_summary() + client.relationship_api.bulk_create_relationships(relationships, threads=args.threads) if __name__ == "__main__": diff --git a/exabel_data_sdk/scripts/load_time_series_from_csv.py b/exabel_data_sdk/scripts/load_time_series_from_csv.py index 70536ef..62251ce 100644 --- a/exabel_data_sdk/scripts/load_time_series_from_csv.py +++ b/exabel_data_sdk/scripts/load_time_series_from_csv.py @@ -7,7 +7,6 @@ from exabel_data_sdk import ExabelClient from exabel_data_sdk.client.api.data_classes.signal import Signal -from exabel_data_sdk.client.api.resource_creation_result import status_callback from exabel_data_sdk.scripts.csv_script import CsvScript from exabel_data_sdk.util.resource_name_normalization import to_entity_resource_names @@ -61,7 +60,7 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: if args.dry_run: print("Running dry-run...") - ts_data = self.read_csv(args) + ts_data = self.read_csv(args, string_columns=[0]) ts_data.iloc[:, 0] = 
to_entity_resource_names( client.entity_api, ts_data.iloc[:, 0], namespace=args.namespace @@ -109,10 +108,9 @@ def run_script(self, client: ExabelClient, args: argparse.Namespace) -> None: print(f" {ts.name}") return - results = client.time_series_api.bulk_upsert_time_series( - series, create_tag=True, status_callback=status_callback + client.time_series_api.bulk_upsert_time_series( + series, create_tag=True, threads=args.threads ) - results.print_summary() if __name__ == "__main__": diff --git a/exabel_data_sdk/tests/client/api/mock_entity_api.py b/exabel_data_sdk/tests/client/api/mock_entity_api.py new file mode 100644 index 0000000..3207a5f --- /dev/null +++ b/exabel_data_sdk/tests/client/api/mock_entity_api.py @@ -0,0 +1,54 @@ +from typing import Optional, Sequence + +from google.protobuf.field_mask_pb2 import FieldMask + +from exabel_data_sdk.client.api.data_classes.entity import Entity +from exabel_data_sdk.client.api.data_classes.entity_type import EntityType +from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult +from exabel_data_sdk.client.api.entity_api import EntityApi +from exabel_data_sdk.tests.client.api.mock_resource_store import MockResourceStore + + +# pylint: disable=super-init-not-called +class MockEntityApi(EntityApi): + """ + Mock of the EntityApi class for CRUD operations on entities and entity types. + """ + + def __init__(self): + self.entities = MockResourceStore() + self.types = MockResourceStore() + self._insert_standard_entity_types() + + def _insert_standard_entity_types(self): + for entity_type in ("brand", "business_segment", "company", "country", "region"): + self.types.create(EntityType("entityTypes/" + entity_type, entity_type, "")) + + def list_entity_types( + self, page_size: int = 1000, page_token: str = None + ) -> PagingResult[EntityType]: + return self.types.list() + + def get_entity_type(self, name: str) -> Optional[EntityType]: + return self.types.get(name) + + def list_entities( + self, entity_type: str, page_size: int = 1000, page_token: str = None + ) -> PagingResult[Entity]: + return self.entities.list(lambda x: x.get_entity_type() == entity_type) + + def get_entity(self, name: str) -> Optional[Entity]: + return self.entities.get(name) + + def create_entity(self, entity: Entity, entity_type: str) -> Entity: + return self.entities.create(entity) + + def update_entity(self, entity: Entity, update_mask: FieldMask = None) -> Entity: + raise NotImplementedError() + + def delete_entity(self, name: str) -> None: + # Note: The mock implementation does not delete associated time series and relationships + self.entities.delete(name) + + def search_for_entities(self, entity_type: str, **search_terms: str) -> Sequence[Entity]: + raise NotImplementedError() diff --git a/exabel_data_sdk/tests/client/api/mock_relationship_api.py b/exabel_data_sdk/tests/client/api/mock_relationship_api.py new file mode 100644 index 0000000..fa160db --- /dev/null +++ b/exabel_data_sdk/tests/client/api/mock_relationship_api.py @@ -0,0 +1,89 @@ +from typing import Optional + +from google.protobuf.field_mask_pb2 import FieldMask + +from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult +from exabel_data_sdk.client.api.data_classes.relationship import Relationship +from exabel_data_sdk.client.api.data_classes.relationship_type import RelationshipType +from exabel_data_sdk.client.api.relationship_api import RelationshipApi +from exabel_data_sdk.tests.client.api.mock_resource_store import MockResourceStore + + +# pylint: 
disable=super-init-not-called +class MockRelationshipApi(RelationshipApi): + """ + Mock of the RelationshipApi class for CRUD operations on relationships and relationship types. + """ + + def __init__(self): + self.relationships = MockResourceStore() + self.types = MockResourceStore() + self._insert_standard_relationship_types() + + def _insert_standard_relationship_types(self): + for rel_type in ("LOCATED_IN", "WEB_DOMAIN_OWNED_BY"): + self.types.create(RelationshipType("relationshipTypes/" + rel_type, rel_type, "")) + + @staticmethod + def _key(relationship: Relationship) -> object: + return (relationship.relationship_type, relationship.from_entity, relationship.to_entity) + + def list_relationship_types( + self, page_size: int = 1000, page_token: str = None + ) -> PagingResult[RelationshipType]: + return self.types.list() + + def get_relationship_type(self, name: str) -> Optional[RelationshipType]: + return self.types.get(name) + + def create_relationship_type(self, relationship_type: RelationshipType) -> RelationshipType: + return self.types.create(relationship_type) + + def update_relationship_type( + self, relationship_type: RelationshipType, update_mask: FieldMask = None + ) -> RelationshipType: + raise NotImplementedError() + + def delete_relationship_type(self, relationship_type: str) -> None: + self.types.delete(relationship_type) + + def get_relationships_from_entity( + self, + relationship_type: str, + from_entity: str, + page_size: int = 1000, + page_token: str = None, + ) -> PagingResult[Relationship]: + raise NotImplementedError() + + def get_relationships_to_entity( + self, + relationship_type: str, + to_entity: str, + page_size: int = 1000, + page_token: str = None, + ) -> PagingResult[Relationship]: + raise NotImplementedError() + + def get_relationship( + self, relationship_type: str, from_entity: str, to_entity: str + ) -> Optional[Relationship]: + return self.relationships.get((relationship_type, from_entity, to_entity)) + + def create_relationship(self, relationship: Relationship) -> Relationship: + return self.relationships.create(relationship, self._key(relationship)) + + def update_relationship( + self, relationship: Relationship, update_mask: FieldMask = None + ) -> Relationship: + raise NotImplementedError() + + def delete_relationship(self, relationship_type: str, from_entity: str, to_entity: str) -> None: + self.relationships.delete((relationship_type, from_entity, to_entity)) + + def list_relationships(self) -> PagingResult[Relationship]: + """ + Returns all relationships. + Note that this method is only available in the mock API, not the real API. + """ + return self.relationships.list() diff --git a/exabel_data_sdk/tests/client/api/mock_resource_store.py b/exabel_data_sdk/tests/client/api/mock_resource_store.py new file mode 100644 index 0000000..b72ff29 --- /dev/null +++ b/exabel_data_sdk/tests/client/api/mock_resource_store.py @@ -0,0 +1,71 @@ +from random import random +from typing import Callable, Dict, Generic, Optional, TypeVar + +from exabel_data_sdk.client.api.data_classes.paging_result import PagingResult +from exabel_data_sdk.client.api.data_classes.request_error import ErrorType, RequestError + +TResource = TypeVar("TResource") + + +def failure_prone(func): + """ + Decorator for class methods that make them raise an exception every now and then, + depending on the 'failure_rate' attribute of the object. 
+    """
+
+    def unreliable(self, *args, **kwargs):
+        if self.failure_rate and random() < self.failure_rate:
+            raise RequestError(ErrorType.UNAVAILABLE, "This is a random failure")
+        return func(self, *args, **kwargs)
+
+    return unreliable
+
+
+class MockResourceStore(Generic[TResource]):
+    """In-memory resource store. Only intended for tests."""
+
+    def __init__(self):
+        self.resources: Dict[object, TResource] = {}
+        # The failure rate, as a fraction (0.0-1.0) of calls that should fail
+        self.failure_rate = 0.0
+
+    def get(self, key: object) -> Optional[TResource]:
+        """Get the resource with the given key if present, otherwise returns None."""
+        return self.resources.get(key, None)
+
+    def list(self, predicate: Callable[[TResource], bool] = None) -> PagingResult[TResource]:
+        """List all resources in the store."""
+        resources = list(self.resources.values())
+        if predicate:
+            resources = list(filter(predicate, resources))
+        return PagingResult(
+            results=resources,
+            next_page_token="next_page_token",
+            total_size=len(resources),
+        )
+
+    @failure_prone
+    def create(self, resource: TResource, key: object = None) -> TResource:
+        """
+        Create the given resource in the store.
+
+        Typically, resources have a resource name, which should be used as the key.
+        In this case, the key parameter should not be set.
+        If not, as in the case of Relationship resources, an explicit key must be provided.
+
+        Args:
+            resource: the resource to create in the store
+            key: the key of the resource. Defaults to the resource's name.
+        """
+        if key is None:
+            key = resource.name  # type: ignore[attr-defined]
+        if key in self.resources:
+            raise RequestError(ErrorType.ALREADY_EXISTS, f"Already exists: {key}")
+        self.resources[key] = resource
+        return resource
+
+    def delete(self, key: object):
+        """Delete the resource with the given key."""
+        if key not in self.resources:
+            raise ValueError(f"Trying to delete non-existent resource: {key}")
+        del self.resources[key]
diff --git a/exabel_data_sdk/tests/client/exabel_mock_client.py b/exabel_data_sdk/tests/client/exabel_mock_client.py
new file mode 100644
index 0000000..b0e083a
--- /dev/null
+++ b/exabel_data_sdk/tests/client/exabel_mock_client.py
@@ -0,0 +1,15 @@
+from exabel_data_sdk import ExabelClient
+from exabel_data_sdk.tests.client.api.mock_entity_api import MockEntityApi
+from exabel_data_sdk.tests.client.api.mock_relationship_api import MockRelationshipApi
+
+
+# pylint: disable=super-init-not-called
+class ExabelMockClient(ExabelClient):
+    """
+    Mock of the ExabelClient that uses mock implementations of the API classes,
+    which only store objects in memory.
+    """
+
+    def __init__(self):
+        self.entity_api = MockEntityApi()
+        self.relationship_api = MockRelationshipApi()
diff --git a/exabel_data_sdk/tests/resources/data/entities.csv b/exabel_data_sdk/tests/resources/data/entities.csv
new file mode 100644
index 0000000..d6646dc
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/entities.csv
@@ -0,0 +1,4 @@
+brand,description
+Spring & Vine,Shampoo bars
+The Coconut Tree,Sri Lankan street food
+Spring & Vine,This entry will be ignored because it's a duplicate
diff --git a/exabel_data_sdk/tests/resources/data/entities2.csv b/exabel_data_sdk/tests/resources/data/entities2.csv
new file mode 100644
index 0000000..9bf1026
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/entities2.csv
@@ -0,0 +1,11 @@
+brand
+Brand A
+Brand B
+Brand C
+Brand D
+Brand E
+Brand F
+Brand G
+Brand H
+Brand I
+Brand J
diff --git a/exabel_data_sdk/tests/resources/data/entities_with_integer_identifiers.csv b/exabel_data_sdk/tests/resources/data/entities_with_integer_identifiers.csv
new file mode 100644
index 0000000..e1c774f
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/entities_with_integer_identifiers.csv
@@ -0,0 +1,3 @@
+brand,brand_name,description
+0001,Spring & Vine,Shampoo bars
+0002,The Coconut Tree,Sri Lankan street food
diff --git a/exabel_data_sdk/tests/resources/data/relationships.csv b/exabel_data_sdk/tests/resources/data/relationships.csv
new file mode 100644
index 0000000..60b99ce
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/relationships.csv
@@ -0,0 +1,4 @@
+entity_from,brand,description
+entityTypes/company/company_x,Spring & Vine,Owned since 2019
+entityTypes/company/company_x,Spring & Vine,This entry will be ignored because it's a duplicate
+entityTypes/company/company_y,The Coconut Tree,Acquired for $200M
diff --git a/exabel_data_sdk/tests/resources/data/relationships_with_integer_identifiers.csv b/exabel_data_sdk/tests/resources/data/relationships_with_integer_identifiers.csv
new file mode 100644
index 0000000..a29ee88
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/relationships_with_integer_identifiers.csv
@@ -0,0 +1,3 @@
+company,brand,description
+0010,0001,Owned since 2019
+0011,0002,Acquired for $200M
diff --git a/exabel_data_sdk/tests/resources/data/timeseries_with_integer_identifiers.csv b/exabel_data_sdk/tests/resources/data/timeseries_with_integer_identifiers.csv
new file mode 100644
index 0000000..ccebe30
--- /dev/null
+++ b/exabel_data_sdk/tests/resources/data/timeseries_with_integer_identifiers.csv
@@ -0,0 +1,8 @@
+brand;date;signal1
+0001;2021-01-01;1
+0001;2021-01-02;2
+0001;2021-01-03;3
+0001;2021-01-04;4
+0001;2021-01-05;5
+0002;2021-01-01;4
+0002;2021-01-03;5
diff --git a/exabel_data_sdk/tests/scripts/common_utils.py b/exabel_data_sdk/tests/scripts/common_utils.py
new file mode 100644
index 0000000..8678dce
--- /dev/null
+++ b/exabel_data_sdk/tests/scripts/common_utils.py
@@ -0,0 +1,14 @@
+from typing import Sequence, Type
+
+from exabel_data_sdk.client.exabel_client import ExabelClient
+from exabel_data_sdk.scripts.csv_script import CsvScript
+from exabel_data_sdk.tests.client.exabel_mock_client import ExabelMockClient
+
+
+def load_test_data_from_csv(csv_script: Type[CsvScript], args: Sequence[str]) -> ExabelClient:
+    """Loads test data into an ExabelMockClient by running the given CSV loading script."""
+    script = csv_script(args, f"Test{csv_script.__name__}")
+    client = ExabelMockClient()
+    script.run_script(client, script.parse_arguments())
+
+    return client
diff --git 
a/exabel_data_sdk/tests/scripts/test_load_entities_from_csv.py b/exabel_data_sdk/tests/scripts/test_load_entities_from_csv.py new file mode 100644 index 0000000..a5c5b99 --- /dev/null +++ b/exabel_data_sdk/tests/scripts/test_load_entities_from_csv.py @@ -0,0 +1,78 @@ +import random +import unittest + +from exabel_data_sdk.client.api.data_classes.entity import Entity +from exabel_data_sdk.scripts.load_entities_from_csv import LoadEntitiesFromCsv +from exabel_data_sdk.tests.scripts.common_utils import load_test_data_from_csv + +common_args = [ + "script-name", + "--namespace", + "test", + "--api-key", + "123", +] + + +class TestLoadEntities(unittest.TestCase): + def test_read_file(self): + args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/entities.csv", + "--description_col", + "description", + ] + client = load_test_data_from_csv(LoadEntitiesFromCsv, args) + expected_entities = [ + Entity( + name="entityTypes/brand/entities/test.Spring_Vine", + display_name="Spring & Vine", + description="Shampoo bars", + ), + Entity( + name="entityTypes/brand/entities/test.The_Coconut_Tree", + display_name="The Coconut Tree", + description="Sri Lankan street food", + ), + ] + self.check_entities(client, expected_entities) + + def test_read_file_with_integer_identifier(self): + file_args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/entities_with_integer_identifiers.csv", + ] + extra_args = [[], ["--name_column", "brand"]] + expected_entities = [ + Entity(name="entityTypes/brand/entities/test.0001", display_name="0001"), + Entity(name="entityTypes/brand/entities/test.0002", display_name="0002"), + ] + for e_args in extra_args: + args = file_args + e_args + client = load_test_data_from_csv(LoadEntitiesFromCsv, args) + self.check_entities(client, expected_entities) + + def check_entities(self, client, expected_entities): + """Check expected entities against actual entities retrieved from the client""" + all_entities = client.entity_api.list_entities("entityTypes/brand").results + self.assertListEqual(sorted(expected_entities), sorted(all_entities)) + for expected_entity in expected_entities: + entity = client.entity_api.get_entity(expected_entity.name) + self.assertEqual(expected_entity, entity) + + def test_read_file_random_errors(self): + random.seed(1) + args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/entities2.csv", + ] + client = load_test_data_from_csv(LoadEntitiesFromCsv, args) + client.entity_api.entities.failure_rate = 0.3 + expected_entities = [ + Entity( + name=f"entityTypes/brand/entities/test.Brand_{letter}", + display_name=f"Brand {letter}", + ) + for letter in "ABCDEFGHIJ" + ] + self.check_entities(client, expected_entities) diff --git a/exabel_data_sdk/tests/scripts/test_load_relationships_from_csv.py b/exabel_data_sdk/tests/scripts/test_load_relationships_from_csv.py new file mode 100644 index 0000000..a8bffd9 --- /dev/null +++ b/exabel_data_sdk/tests/scripts/test_load_relationships_from_csv.py @@ -0,0 +1,85 @@ +import unittest + +from exabel_data_sdk.client.api.data_classes.relationship import Relationship +from exabel_data_sdk.client.api.data_classes.relationship_type import RelationshipType +from exabel_data_sdk.scripts.load_relationships_from_csv import LoadRelationshipsFromCsv +from exabel_data_sdk.tests.scripts.common_utils import load_test_data_from_csv + +common_args = [ + "script-name", + "--namespace", + "acme", + "--api-key", + "123", + "--relationship_type", + "PART_OF", + 
"--entity_to_column", + "brand", +] + + +class TestLoadRelationships(unittest.TestCase): + def test_read_file(self): + args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/relationships.csv", + "--entity_from_column", + "entity_from", + "--description_column", + "description", + ] + client = load_test_data_from_csv(LoadRelationshipsFromCsv, args) + # Check that the relationship type was created + self.assertEqual( + RelationshipType("relationshipTypes/acme.PART_OF"), + client.relationship_api.get_relationship_type("relationshipTypes/acme.PART_OF"), + ) + expected_relationships = [ + Relationship( + relationship_type="relationshipTypes/acme.PART_OF", + from_entity="entityTypes/company/company_x", + to_entity="entityTypes/brand/entities/acme.Spring_Vine", + description="Owned since 2019", + ), + Relationship( + relationship_type="relationshipTypes/acme.PART_OF", + from_entity="entityTypes/company/company_y", + to_entity="entityTypes/brand/entities/acme.The_Coconut_Tree", + description="Acquired for $200M", + ), + ] + self.check_relationships(client, expected_relationships) + + def test_read_file_with_integer_identifiers(self): + args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/relationships_with_integer_identifiers.csv", + "--entity_from_column", + "company", + ] + client = load_test_data_from_csv(LoadRelationshipsFromCsv, args) + expected_relationships = [ + Relationship( + relationship_type="relationshipTypes/acme.PART_OF", + from_entity="entityTypes/company/entities/acme.0010", + to_entity="entityTypes/brand/entities/acme.0001", + ), + Relationship( + relationship_type="relationshipTypes/acme.PART_OF", + from_entity="entityTypes/company/entities/acme.0011", + to_entity="entityTypes/brand/entities/acme.0002", + ), + ] + self.check_relationships(client, expected_relationships) + + def check_relationships(self, client, expected_relationships): + """Check expected entities against actual entities retrieved from the client""" + all_relationships = client.relationship_api.list_relationships().results + self.assertListEqual(expected_relationships, all_relationships) + for expected_relationship in expected_relationships: + relationship = client.relationship_api.get_relationship( + expected_relationship.relationship_type, + expected_relationship.from_entity, + expected_relationship.to_entity, + ) + self.assertEqual(expected_relationship, relationship) diff --git a/exabel_data_sdk/tests/scripts/test_load_time_series_from_csv.py b/exabel_data_sdk/tests/scripts/test_load_time_series_from_csv.py index b797f95..a2825da 100644 --- a/exabel_data_sdk/tests/scripts/test_load_time_series_from_csv.py +++ b/exabel_data_sdk/tests/scripts/test_load_time_series_from_csv.py @@ -8,6 +8,14 @@ from exabel_data_sdk import ExabelClient from exabel_data_sdk.scripts.load_time_series_from_csv import LoadTimeSeriesFromCsv +common_args = [ + "script-name", + "--sep", + ";", + "--api-key", + "123", +] + class TestUploadTimeSeries(unittest.TestCase): def test_one_signal(self): @@ -80,16 +88,11 @@ def test_two_signals(self): ) def test_read_file_use_header_for_signal(self): - args = [ - "script-name", + args = common_args + [ "--filename", "./exabel_data_sdk/tests/resources/data/timeseries.csv", - "--sep", - ";", "--namespace", "", - "--api-key", - "123", ] script = LoadTimeSeriesFromCsv(args, "LoadTest1") @@ -121,16 +124,11 @@ def test_read_file_use_header_for_signal(self): ) def test_read_file_with_multiple_signals(self): - args = [ - "script-name", + args = 
common_args + [ "--filename", "./exabel_data_sdk/tests/resources/data/timeseries_multiple_signals.csv", - "--sep", - ";", "--namespace", "acme", - "--api-key", - "123", ] script = LoadTimeSeriesFromCsv(args, "LoadTest3") client = mock.create_autospec(ExabelClient(host="host", api_key="123")) @@ -173,6 +171,42 @@ def test_read_file_with_multiple_signals(self): series[3], ) + def test_read_file_with_integer_identifiers(self): + args = common_args + [ + "--filename", + "./exabel_data_sdk/tests/resources/data/timeseries_with_integer_identifiers.csv", + "--namespace", + "acme", + ] + + script = LoadTimeSeriesFromCsv(args, "LoadTest4") + client = mock.create_autospec(ExabelClient(host="host", api_key="123")) + script.run_script(client, script.parse_arguments()) + + call_args_list = client.time_series_api.bulk_upsert_time_series.call_args_list + self.assertEqual(1, len(call_args_list)) + series = call_args_list[0][0][0] + self.assertEqual(2, len(series)) + + pd.testing.assert_series_equal( + pd.Series( + range(1, 6), + pd.date_range("2021-01-01", periods=5, tz=tz.tzutc()), + name="entityTypes/brand/entities/acme.0001/signals/acme.signal1", + ), + series[0], + check_freq=False, + ) + pd.testing.assert_series_equal( + pd.Series( + [4, 5], + pd.DatetimeIndex(["2021-01-01", "2021-01-03"], tz=tz.tzutc()), + name="entityTypes/brand/entities/acme.0002/signals/acme.signal1", + ), + series[1], + check_freq=False, + ) + if __name__ == "__main__": unittest.main() diff --git a/exabel_data_sdk/tests/util/test_resource_name_normalization.py b/exabel_data_sdk/tests/util/test_resource_name_normalization.py index b634325..2479dc3 100644 --- a/exabel_data_sdk/tests/util/test_resource_name_normalization.py +++ b/exabel_data_sdk/tests/util/test_resource_name_normalization.py @@ -7,6 +7,7 @@ from exabel_data_sdk.client.api.entity_api import EntityApi from exabel_data_sdk.client.client_config import ClientConfig from exabel_data_sdk.util.resource_name_normalization import ( + _assert_no_collision, normalize_resource_name, to_entity_resource_names, ) @@ -65,3 +66,66 @@ def test_isin_mapping(self): "Arguments not as expected", ) pd.testing.assert_series_equal(expected, result) + + def test_micticker_mapping(self): + # Note that "NO?COLON" and "TOO:MANY:COLONS" are illegal mic:ticker identifiers, + # since any legal identifier must contain exactly one colon. + # The to_entity_resource_names function will print a warning for such illegal identifiers, + # and they will not result in any searches towards the Exabel API. 
+        data = pd.Series(
+            [
+                "XOSL:TEL",
+                "XNAS:AAPL",
+                "NO?COLON",
+                "TOO:MANY:COLONS",
+                "XOSL:ORK",
+                "MANY:HITS",
+                "NO:HITS",
+            ],
+            name="mic:ticker",
+        )
+        expected = pd.Series(
+            [
+                "entityTypes/company/entities/telenor_asa",
+                "entityTypes/company/entities/apple_inc",
+                None,
+                None,
+                "entityTypes/company/entities/orkla_asa",
+                None,
+                None,
+            ],
+            name="entity",
+        )
+        entity_api = mock.create_autospec(EntityApi(ClientConfig(api_key="123"), use_json=True))
+        entity_api.search_for_entities.side_effect = [
+            [Entity("entityTypes/company/entities/telenor_asa", "Telenor ASA")],
+            [Entity("entityTypes/company/entities/apple_inc", "Apple, Inc.")],
+            [Entity("entityTypes/company/entities/orkla_asa", "Orkla ASA")],
+            # Result for "MANY:HITS"
+            [
+                Entity("entityTypes/company/entities/orkla_asa", "Orkla ASA"),
+                Entity("entityTypes/company/entities/telenor_asa", "Telenor ASA"),
+            ],
+            # Result for "NO:HITS"
+            [],
+        ]
+        result = to_entity_resource_names(entity_api, data, namespace="acme")
+        pd.testing.assert_series_equal(expected, result)
+
+        # Check that the expected searches were performed
+        call_args_list = entity_api.search_for_entities.call_args_list
+        expected_searches = ["XOSL:TEL", "XNAS:AAPL", "XOSL:ORK", "MANY:HITS", "NO:HITS"]
+        self.assertEqual(len(expected_searches), len(call_args_list))
+        for i, identifier in enumerate(expected_searches):
+            mic, ticker = identifier.split(":")
+            self.assertEqual(
+                {"entity_type": "entityTypes/company", "mic": mic, "ticker": ticker},
+                call_args_list[i][1],
+                "Arguments not as expected",
+            )
+
+    def test_name_collision(self):
+        bad_mapping = {"Abc!": "Abc_", "Abcd": "Abcd", "Abc?": "Abc_"}
+        self.assertRaises(SystemExit, _assert_no_collision, bad_mapping)
+        good_mapping = {"Abc!": "Abc_1", "Abcd": "Abcd", "Abc?": "Abc_2"}
+        _assert_no_collision(good_mapping)
diff --git a/exabel_data_sdk/util/resource_name_normalization.py b/exabel_data_sdk/util/resource_name_normalization.py
index b0c9826..ff13780 100644
--- a/exabel_data_sdk/util/resource_name_normalization.py
+++ b/exabel_data_sdk/util/resource_name_normalization.py
@@ -1,4 +1,6 @@
 import re
+import sys
+from typing import Mapping
 
 import pandas as pd
 
@@ -28,6 +30,31 @@ def normalize_resource_name(name: str) -> str:
     return name
 
 
+def _assert_no_collision(mapping: Mapping[str, str]) -> None:
+    """
+    Verify that the normalization of identifiers hasn't introduced any name collisions.
+    If there are collisions, a message is printed informing about the collisions,
+    and the script exits with an error code of 1.
+
+    Args:
+        mapping: a map from external identifier to a normalized resource name
+
+    Raises:
+        SystemExit if there are two identifiers that map to the same resource name
+    """
+    series = pd.Series(mapping)
+    duplicates = series[series.duplicated(keep=False)]
+    if duplicates.empty:
+        # No duplicates, all good
+        return
+    print("The normalization of identifiers has introduced resource name collisions.")
+    print("The collisions are shown below.")
+    print("Please fix these duplicates, and then re-run the script.")
+    pd.set_option("max_colwidth", 1000)
+    print(duplicates.sort_values().to_string())
+    sys.exit(1)
+
+
 def to_entity_resource_names(
     entity_api: EntityApi, identifiers: pd.Series, namespace: str = None
 ) -> pd.Series:
@@ -37,25 +64,36 @@
 
     The name of the given series is used to determine what kind of identifier it is.
These are the legal series names, and how each case is handled: - - entity (or entity_from or entity_to): + - entity (or entity_from or entity_to) The given identifiers are the entity resource names. The identifiers are returned unaltered. - - isin: + - isin The given identifiers are ISIN numbers. The ISIN numbers are looked up with the Exabel API, and the Exabel resource identifiers for the associated companies are returned. - - factset_identifier: + - factset_identifier The given identifiers are FactSet IDs. The identifiers are looked up with the Exabel API, and the Exabel resource identifiers are returned. - - bloomberg_ticker: + - bloomberg_ticker The given identifiers are Bloomberg tickers. The tickers are looked up with the Exabel API, and the Exabel resource identifiers are returned. + - mic:ticker + The given identifiers are the combination of MIC and stock ticker, separated by a colon. + MIC is the Market Identifier Code of the stock exchange where the stock is traded under + the given ticker. + The MIC/ticker combinations are looked up with the Exabel API, and the Exabel resource + identifiers are returned. + Examples: + XNAS:AAPL refers to Apple, Inc. on NASDAQ + XNYS:GE refers to General Electric Co. on the New York Stock Exchange + XOSL:TEL refers to Telenor ASA on the Oslo Stock Exchange + - any known entity type, e.g. "brand" or "product_type": The given identifiers are customer provided names. The names are first normalized (using the normalize_resource_name method) @@ -75,7 +113,7 @@ def to_entity_resource_names( unique_ids = identifiers.unique() - if name in ("isin", "factset_identifier", "bloomberg_ticker"): + if name in ("isin", "factset_identifier", "bloomberg_ticker", "mic:ticker"): # A company identifier print(f"Looking up {len(unique_ids)} {name}s...") mapping = {} @@ -83,7 +121,14 @@ def to_entity_resource_names( if not identifier: # Skip empty identifiers continue - search_terms = {name: identifier} + if name == "mic:ticker": + parts = identifier.split(":") + if len(parts) != 2: + print("mic:ticker must contain exactly one colon (:), but got:", identifier) + continue + search_terms = {"mic": parts[0], "ticker": parts[1]} + else: + search_terms = {name: identifier} entities = entity_api.search_for_entities( entity_type="entityTypes/company", **search_terms ) @@ -118,6 +163,7 @@ def to_entity_resource_names( for identifier in unique_ids if identifier } + _assert_no_collision(mapping) result = identifiers.map(mapping) result.name = "entity"
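
Usage sketch (illustrative only, not part of the diff): the snippet below shows how the
bulk-loading pieces introduced in this change fit together when called from the SDK. The
API key and the "acme" entity names are hypothetical placeholders; the calls themselves
(Entity, bulk_create_entities with the new threads parameter, and the retry and progress
behaviour of bulk_insert) are the ones added above.

from exabel_data_sdk import ExabelClient
from exabel_data_sdk.client.api.data_classes.entity import Entity

# Hypothetical credentials; replace with a real API key.
client = ExabelClient(api_key="my-api-key")

# Hypothetical entities in the "acme" namespace.
entities = [
    Entity(
        name=f"entityTypes/brand/entities/acme.brand_{i}",
        display_name=f"Brand {i}",
    )
    for i in range(100)
]

# Inserts run on up to 20 parallel threads. bulk_insert retries requests that fail
# with a retryable error (UNAVAILABLE, TIMEOUT or INTERNAL) up to two times, and
# ResourceCreationResults prints a progress line for every 20 processed resources.
results = client.entity_api.bulk_create_entities(
    entities, entity_type="entityTypes/brand", threads=20
)
print(results.count())  # total number of processed resources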