diff --git a/src/api/README.md b/src/api/README.md index 4ff8d7683..0f4c94167 100644 --- a/src/api/README.md +++ b/src/api/README.md @@ -34,6 +34,34 @@ Ensure that you setup the **local.settings.json** file with the relevant paramet |---------|-------| |DATABRICKS_SQL_SERVER_HOSTNAME|adb-xxxxx.x.azuredatabricks.net| |DATABRICKS_SQL_HTTP_PATH|/sql/1.0/warehouses/xxx| +|DATABRICKS_SERVING_ENDPOINT|https://adb-xxxxx.x.azuredatabricks.net/serving-endpoints/xxxxxxx/invocations| +|BATCH_THREADPOOL_WORKERS|3| +|LOOKUP_THREADPOOL_WORKERS|10| + +### Information: + +DATABRICKS_SERVING_ENDPOINT +- **This is an optional parameter** +- This represents a Databricks feature serving endpoint, which is used to provide lower-latency lookups of Databricks tables. +- In this API, it is used to map tag names to their respective "CatalogName", "SchemaName" and "DataTable". +- This enables the business_unit, asset and data_security_level parameters to be optional, thereby reducing user friction when querying data. +- Because these parameters are optional, custom validation logic based on the presence (or absence) of the mapping endpoint is implemented in models.py via Pydantic. +- A sketch of the request and response shape for this endpoint is shown at the end of this section. +- For more information on feature serving endpoints please see: https://docs.databricks.com/en/machine-learning/feature-store/feature-function-serving.html + +LOOKUP_THREADPOOL_WORKERS +- **This is an optional parameter** +- When a query contains multiple tags that reside in multiple tables, the API queries each table separately and concatenates the results. +- This parameter parallelises those per-table requests. +- This defaults to 3 if it is not defined in the .env. + +BATCH_THREADPOOL_WORKERS +- **This is an optional parameter** +- This represents the number of workers used to parallelise the requests in a batch sent to the /batch route. +- This defaults to the CPU count minus one if it is not defined in the .env. + +Please note that the batch API route calls the lookup under the hood by default. Therefore, if there are many requests, each requiring multiple tables, the total number of threads can be up to BATCH_THREADPOOL_WORKERS * LOOKUP_THREADPOOL_WORKERS. +For example, 10 requests in a batch, each querying 3 tables, means there can be up to 30 simultaneous queries. +Therefore, it is recommended to set these parameters for performance optimization. Please also ensure that you install all the turbodbc requirements for your machine by reviewing the [installation instructions](https://turbodbc.readthedocs.io/en/latest/pages/getting_started.html) of turbodbc.
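As a rough illustration of what the tag lookup does behind the scenes, the minimal sketch below posts tag names to the feature serving endpoint and reads back the catalog, schema and table for each tag. The endpoint URL, token and tag names are placeholders only; the request and response shape mirrors the lookup the API performs internally.

```python
import json
import os

import requests

# Placeholders for illustration only.
mapping_endpoint = os.environ["DATABRICKS_SERVING_ENDPOINT"]
databricks_token = "<databricks-token>"
tags = ["TAG-0001", "TAG-0002"]

# The serving endpoint expects one record per tag name.
payload = {"dataframe_records": [{"TagName": tag} for tag in tags]}

response = requests.post(
    url=mapping_endpoint,
    headers={
        "Authorization": f"Bearer {databricks_token}",
        "Content-Type": "application/json",
    },
    data=json.dumps(payload),
)
response.raise_for_status()

# Each output row contains the CatalogName, SchemaName and DataTable for a tag,
# which the API combines into a full table name such as catalog.schema.table.
for row in response.json()["outputs"]:
    print(row["TagName"], row["CatalogName"], row["SchemaName"], row["DataTable"])
```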
On a MacBook, this includes executing the following commands: diff --git a/src/api/requirements.txt b/src/api/requirements.txt index 76a6f8767..9e30382de 100644 --- a/src/api/requirements.txt +++ b/src/api/requirements.txt @@ -21,4 +21,4 @@ googleapis-common-protos>=1.56.4 langchain>=0.2.0,<0.3.0 langchain-community>=0.2.0,<0.3.0 openai==1.13.3 -pyjwt==2.8.0 \ No newline at end of file +pyjwt==2.8.0 diff --git a/src/api/v1/__init__.py b/src/api/v1/__init__.py index 37f5865af..cd8848c67 100644 --- a/src/api/v1/__init__.py +++ b/src/api/v1/__init__.py @@ -30,6 +30,7 @@ circular_average, circular_standard_deviation, summary, + batch, ) from src.api.auth.azuread import oauth2_scheme diff --git a/src/api/v1/batch.py b/src/api/v1/batch.py new file mode 100755 index 000000000..2e42388d8 --- /dev/null +++ b/src/api/v1/batch.py @@ -0,0 +1,144 @@ +# Copyright 2022 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import numpy as np +import os +from fastapi import HTTPException, Depends, Body + +from src.api.v1.models import ( + BaseQueryParams, + BaseHeaders, + BatchBodyParams, + BatchResponse, + LimitOffsetQueryParams, + HTTPError, +) +from src.api.auth.azuread import oauth2_scheme +from src.api.v1.common import ( + common_api_setup_tasks, + json_response_batch, + lookup_before_get, +) +from src.api.FastAPIApp import api_v1_router +from concurrent.futures import ThreadPoolExecutor + + +ROUTE_FUNCTION_MAPPING = { + "/api/v1/events/raw": "raw", + "/api/v1/events/latest": "latest", + "/api/v1/events/resample": "resample", + "/api/v1/events/plot": "plot", + "/api/v1/events/interpolate": "interpolate", + "/api/v1/events/interpolationattime": "interpolationattime", + "/api/v1/events/circularaverage": "circularaverage", + "/api/v1/events/circularstandarddeviation": "circularstandarddeviation", + "/api/v1/events/timeweightedaverage": "timeweightedaverage", + "/api/v1/events/summary": "summary", + "/api/v1/events/metadata": "metadata", + "/api/v1/sql/execute": "execute", +} + + +async def batch_events_get( + base_query_parameters, base_headers, batch_query_parameters, limit_offset_parameters +): + try: + (connection, parameters) = common_api_setup_tasks( + base_query_parameters=base_query_parameters, + base_headers=base_headers, + ) + + # Validate the parameters + parsed_requests = [] + for request in batch_query_parameters.requests: + # If required, combine request body and parameters: + parameters = request["params"] + if request["method"] == "POST": + if request["body"] is None: + raise Exception( + "Incorrectly formatted request provided: All POST requests require a body" + ) + parameters = {**parameters, **request["body"]} + + # Map the url to a specific function + try: + func = ROUTE_FUNCTION_MAPPING[request["url"]] + except KeyError: + raise Exception( + "Unsupported url: Only relative base urls are supported. 
Please provide any parameters in the params key" + ) + + # Rename tag_name to tag_names, if required + if "tag_name" in parameters.keys(): + parameters["tag_names"] = parameters.pop("tag_name") + + # Append to array + parsed_requests.append({"func": func, "parameters": parameters}) + + # Obtain max workers from environment var, otherwise default to one less than cpu count + max_workers = os.environ.get("BATCH_THREADPOOL_WORKERS", os.cpu_count() - 1) + + # Request the data for each concurrently with threadpool + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Use executor.map to preserve order + results = executor.map( + lambda arguments: lookup_before_get(*arguments), + [ + (parsed_request["func"], connection, parsed_request["parameters"]) + for parsed_request in parsed_requests + ], + ) + + return json_response_batch(results) + + except Exception as e: + print(e) + logging.error(str(e)) + raise HTTPException(status_code=400, detail=str(e)) + + +post_description = """ +## Batch + +Retrieval of timeseries data via a POST method to enable providing a list of requests including the route and parameters +""" + + +@api_v1_router.post( + path="/events/batch", + name="Batch POST", + description=post_description, + tags=["Events"], + dependencies=[Depends(oauth2_scheme)], + responses={200: {"model": BatchResponse}, 400: {"model": HTTPError}}, + openapi_extra={ + "externalDocs": { + "description": "RTDIP Batch Query Documentation", + "url": "https://www.rtdip.io/sdk/code-reference/query/functions/time_series/batch/", + } + }, +) +async def batch_post( + base_query_parameters: BaseQueryParams = Depends(), + batch_query_parameters: BatchBodyParams = Body(default=...), + base_headers: BaseHeaders = Depends(), + limit_offset_query_parameters: LimitOffsetQueryParams = Depends(), +): + return await batch_events_get( + base_query_parameters, + base_headers, + batch_query_parameters, + limit_offset_query_parameters, + ) diff --git a/src/api/v1/circular_average.py b/src/api/v1/circular_average.py index e9ccf12a3..382f2a32d 100644 --- a/src/api/v1/circular_average.py +++ b/src/api/v1/circular_average.py @@ -32,7 +32,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def circular_average_events_get( @@ -55,7 +55,15 @@ def circular_average_events_get( base_headers=base_headers, ) - data = circular_average.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = circular_average.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("circular_average", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/circular_standard_deviation.py b/src/api/v1/circular_standard_deviation.py index 6069e4081..836a958a6 100644 --- a/src/api/v1/circular_standard_deviation.py +++ b/src/api/v1/circular_standard_deviation.py @@ -33,7 +33,7 @@ LimitOffsetQueryParams, CircularAverageQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def circular_standard_deviation_events_get( @@ 
-56,7 +56,17 @@ def circular_standard_deviation_events_get( base_headers=base_headers, ) - data = circular_standard_deviation.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = circular_standard_deviation.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get( + "circular_standard_deviation", connection, parameters + ) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/common.py b/src/api/v1/common.py index 4aac0eaa8..1e70753b0 100644 --- a/src/api/v1/common.py +++ b/src/api/v1/common.py @@ -15,23 +15,35 @@ from datetime import datetime import json import os -import pandas as pd import importlib.util -from typing import Any + +from typing import Any, List, Dict, Union +import requests +import json +import pandas as pd +import numpy as np + from fastapi import Response +from fastapi.responses import JSONResponse import dateutil.parser + from pandas import DataFrame import pyarrow as pa from pandas.io.json import build_table_schema + from src.sdk.python.rtdip_sdk.connectors import ( DatabricksSQLConnection, ConnectionReturnType, ) +from src.sdk.python.rtdip_sdk.queries.time_series import batch + + if importlib.util.find_spec("turbodbc") != None: from src.sdk.python.rtdip_sdk.connectors import TURBODBCSQLConnection from src.api.auth import azuread from .models import BaseHeaders, FieldSchema, LimitOffsetQueryParams, PaginationRow +from decimal import Decimal def common_api_setup_tasks( # NOSONAR @@ -184,23 +196,206 @@ def datetime_parser(json_dict): def json_response( - data: dict, limit_offset_parameters: LimitOffsetQueryParams + data: Union[dict, DataFrame], limit_offset_parameters: LimitOffsetQueryParams ) -> Response: - schema_df = pd.DataFrame() - if data["data"] is not None and data["data"] != "": - json_str = data["sample_row"] - json_dict = json.loads(json_str, object_hook=datetime_parser) - schema_df = pd.json_normalize(json_dict) - - return Response( - content="{" - + '"schema":{},"data":{},"pagination":{}'.format( - FieldSchema.model_validate( - build_table_schema(schema_df, index=False, primary_key=False), - ).model_dump_json(), - "[" + data["data"] + "]", - pagination(limit_offset_parameters, data["count"]).model_dump_json(), + if isinstance(data, DataFrame): + return Response( + content="{" + + '"schema":{},"data":{},"pagination":{}'.format( + FieldSchema.model_validate( + build_table_schema(data, index=False, primary_key=False), + ).model_dump_json(), + data.replace({np.nan: None}).to_json( + orient="records", date_format="iso", date_unit="ns" + ), + pagination(limit_offset_parameters, data).model_dump_json(), + ) + + "}", + media_type="application/json", + ) + else: + schema_df = pd.DataFrame() + if data["data"] is not None and data["data"] != "": + json_str = data["sample_row"] + json_dict = json.loads(json_str, object_hook=datetime_parser) + schema_df = pd.json_normalize(json_dict) + + return Response( + content="{" + + '"schema":{},"data":{},"pagination":{}'.format( + FieldSchema.model_validate( + build_table_schema(schema_df, index=False, primary_key=False), + ).model_dump_json(), + "[" + data["data"] + "]", + pagination(limit_offset_parameters, data["count"]).model_dump_json(), + ) + + "}", + media_type="application/json", + ) 
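For orientation, a minimal sketch of the body shape that the DataFrame branch of json_response above produces; the field names, values and pagination entries here are illustrative only, not taken from a real table.

```python
# Illustrative response body only; real schemas come from pandas' build_table_schema.
example_body = {
    "schema": {
        "fields": [
            {"name": "EventTime", "type": "datetime"},
            {"name": "TagName", "type": "string"},
            {"name": "Value", "type": "number"},
        ],
        "primaryKey": False,
        "pandas_version": "1.4.0",
    },
    "data": [
        {"EventTime": "2024-01-01T00:00:00.000000000", "TagName": "TAG-0001", "Value": 1.5}
    ],
    "pagination": {"limit": None, "offset": None, "next": None},
}
```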
+ + +def json_response_batch(data_list: List[DataFrame]) -> Response: + # Function to parse dataframe into dictionary along with schema + def get_as_dict(data): + def convert_value(x): + if isinstance(x, pd.Timestamp): + return x.isoformat(timespec="nanoseconds") + elif isinstance(x, pd.Timedelta): + return x.isoformat() + elif isinstance(x, Decimal): + return float(x) + return x + + data_parsed = data.replace({np.nan: None}).map(convert_value) + schema = build_table_schema(data_parsed, index=False, primary_key=False) + data_dict = data_parsed.to_dict(orient="records") + + return {"schema": schema, "data": data_dict} + + # Parse each dataframe into a dictionary containing the schema and the data as dict + dict_content = {"data": [get_as_dict(data) for data in data_list]} + + return JSONResponse(content=dict_content) + + +def lookup_before_get( + func_name: str, connection: DatabricksSQLConnection, parameters: Dict +): + # Ensure returns data as DataFrames + parameters["to_json"] = False + + # query mapping endpoint for tablenames - returns tags as array under each table key + tag_table_mapping = query_mapping_endpoint( + tags=parameters["tag_names"], + mapping_endpoint=os.getenv("DATABRICKS_SERVING_ENDPOINT"), + connection=connection, + ) + + # create list of parameter dicts for each table + request_list = [] + for table in tag_table_mapping: + params = parameters.copy() + params["tag_names"] = tag_table_mapping[table] + params.update( + split_table_name(table) + ) # Adds business_unit, asset, data_security_level, data_type + request = {"type": func_name, "parameters_dict": params} + request_list.append(request) + + # make default workers 3 as within one query typically will request from only a few tables at once + max_workers = os.environ.get("LOOKUP_THREADPOOL_WORKERS", 3) + + # run function with each parameters concurrently + results = batch.get(connection, request_list, threadpool_max_workers=max_workers) + + # Append/concat results as required + data = concatenate_dfs_and_order( + dfs_arr=results, pivot=False, tags=parameters["tag_names"] + ) + + return data + + +def query_mapping_endpoint(tags: list, mapping_endpoint: str, connection: Dict): + # Form header dict with token from connection + token = swap_for_databricks_token(connection.access_token) + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + # Create body of request + data = {"dataframe_records": [{"TagName": tag} for tag in tags]} + data_json = json.dumps(data, allow_nan=True) + + # Make request to mapping endpoint + response = requests.post(headers=headers, url=mapping_endpoint, data=data_json) + if response.status_code != 200: + raise Exception( + f"Request failed with status {response.status_code}, {response.text}" + ) + result = response.json() + + # Map tags to tables, where all tags belonging to each table are stored in an array + tag_table_mapping = {} + for row in result["outputs"]: + # Check results are returned + if any(row[x] == None for x in ["CatalogName", "SchemaName", "DataTable"]): + raise Exception( + f"One or more tags do not have tables associated with them, the data belongs to a confidential table, or you do not have access. 
If the tag belongs to a confidential table and you do have access, please supply the business_unit, asset, data_security_level and data_type" + ) + + # Construct full tablename from output + table_name = f"""{row["CatalogName"]}.{row["SchemaName"]}.{row["DataTable"]}""" + + # Store table names along with tags in dict (all tags that share table under same key) + if table_name not in tag_table_mapping: + tag_table_mapping[table_name] = [] + + tag_table_mapping[table_name].append(row["TagName"]) + + return tag_table_mapping + + +def split_table_name(str): + try: + # Retireve parts by splitting string + parts = str.split(".") + business_unit = parts[0] + schema = parts[1] + asset_security_type = parts[2].split("_") + + # check if of correct format + if schema != "sensors" and ("events" not in str or "metadata" not in str): + raise Exception() + + # Get the asset, data security level and type + asset = asset_security_type[0].lower() + data_security_level = asset_security_type[1].lower() + data_type = asset_security_type[ + len(asset_security_type) - 1 + ].lower() # i.e. the final part + + # Return the formatted object + return { + "business_unit": business_unit, + "asset": asset, + "data_security_level": data_security_level, + "data_type": data_type, + } + except Exception as e: + raise Exception( + "Unsupported table name format supplied. Please use the format 'businessunit.schema.asset.datasecurityevel_events_datatype" ) - + "}", - media_type="application/json", + + +def concatenate_dfs_and_order(dfs_arr: List[DataFrame], pivot: bool, tags: list): + if pivot: + # If pivoted, then must add columns horizontally + concat_df = pd.concat(dfs_arr, axis=1, ignore_index=False) + concat_df = concat_df.loc[:, ~concat_df.columns.duplicated()] + + # reorder columns so that they match the order of the tags provided + time_col = concat_df.columns.to_list()[0] + cols = [time_col, *tags] + concat_df = concat_df[cols] + + else: + # Otherwise, can concat vertically + concat_df = pd.concat(dfs_arr, axis=0, ignore_index=True) + + return concat_df + + +def swap_for_databricks_token(azure_ad_token): + DATABRICKS_SQL_SERVER_HOSTNAME = os.getenv("DATABRICKS_SQL_SERVER_HOSTNAME") + + token_response = requests.post( + f"https://{DATABRICKS_SQL_SERVER_HOSTNAME}/api/2.0/token/create", + headers={"Authorization": f"Bearer {azure_ad_token}"}, + json={"comment": "tag mapping token", "lifetime_seconds": 360}, ) + + if token_response.status_code == 200: + DATABRICKS_TOKEN = token_response.json().get("token_value") + else: + DATABRICKS_TOKEN = "" + + return DATABRICKS_TOKEN diff --git a/src/api/v1/interpolate.py b/src/api/v1/interpolate.py index f0a89dfc5..a25d097b1 100644 --- a/src/api/v1/interpolate.py +++ b/src/api/v1/interpolate.py @@ -33,7 +33,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def interpolate_events_get( @@ -58,7 +58,15 @@ def interpolate_events_get( base_headers=base_headers, ) - data = interpolate.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = interpolate.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("interpolate", connection, parameters) 
return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/interpolation_at_time.py b/src/api/v1/interpolation_at_time.py index c41f53033..cc812bc25 100644 --- a/src/api/v1/interpolation_at_time.py +++ b/src/api/v1/interpolation_at_time.py @@ -30,7 +30,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def interpolation_at_time_events_get( @@ -51,7 +51,15 @@ def interpolation_at_time_events_get( base_headers=base_headers, ) - data = interpolation_at_time.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = interpolation_at_time.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("interpolation_at_time", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/latest.py b/src/api/v1/latest.py index db5cdaa57..e39bb4ed7 100644 --- a/src/api/v1/latest.py +++ b/src/api/v1/latest.py @@ -27,7 +27,7 @@ HTTPError, ) from src.api.auth.azuread import oauth2_scheme -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get from src.api.FastAPIApp import api_v1_router @@ -42,7 +42,15 @@ def latest_retrieval_get( base_headers=base_headers, ) - data = latest.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level"] + ): + # if have all required params, run normally + data = latest.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("latest", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/metadata.py b/src/api/v1/metadata.py index dd2595dc0..4470e8dca 100644 --- a/src/api/v1/metadata.py +++ b/src/api/v1/metadata.py @@ -25,7 +25,7 @@ HTTPError, ) from src.api.auth.azuread import oauth2_scheme -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get from src.api.FastAPIApp import api_v1_router @@ -40,7 +40,15 @@ def metadata_retrieval_get( base_headers=base_headers, ) - data = metadata.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level"] + ): + # if have all required params, run normally + data = metadata.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("metadata", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/models.py b/src/api/v1/models.py index 61f0dd2b4..34461f48b 100644 --- a/src/api/v1/models.py +++ b/src/api/v1/models.py @@ -22,12 +22,13 @@ Field, Strict, field_serializer, + BaseModel, ) from typing import Annotated, List, Union, Dict, Any -from 
fastapi import Query, Header, Depends +from fastapi import Query, Header, Depends, HTTPException from datetime import date from src.api.auth.azuread import oauth2_scheme -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Optional EXAMPLE_DATE = "2022-01-01" @@ -230,12 +231,21 @@ def __init__( class BaseQueryParams: def __init__( self, - business_unit: str = Query(..., description="Business Unit Name"), + business_unit: str = Query(None, description="Business Unit Name"), region: str = Query(..., description="Region"), - asset: str = Query(..., description="Asset"), - data_security_level: str = Query(..., description="Data Security Level"), + asset: str = Query(None, description="Asset"), + data_security_level: str = Query(None, description="Data Security Level"), authorization: str = Depends(oauth2_scheme), ): + # Additional validation when mapping endpoint not provided - ensure validation error for missing params + if not os.getenv("DATABRICKS_SERVING_ENDPOINT"): + required_params = { + "business_unit": business_unit, + "asset": asset, + "data_security_level": data_security_level, + } + additionaly_validate_params(required_params) + self.business_unit = business_unit self.region = region self.asset = asset @@ -258,11 +268,29 @@ def check_date(v: str) -> str: return v +def additionaly_validate_params(required_params): + # Checks if any of the supplied parameters are missing, and throws HTTPException in pydantic format + errors = [] + for field in required_params.keys(): + if required_params[field] is None: + errors.append( + { + "type": "missing", + "loc": ("query", field), + "msg": "Field required", + "input": required_params[field], + } + ) + if len(errors) > 0: + print(errors) + raise HTTPException(status_code=422, detail=errors) + + class RawQueryParams: def __init__( self, data_type: str = Query( - ..., + None, description="Data Type can be one of the following options: float, double, integer, string", examples=["float", "double", "integer", "string"], ), @@ -282,6 +310,11 @@ def __init__( examples=[EXAMPLE_DATE, EXAMPLE_DATETIME, EXAMPLE_DATETIME_TIMEZOME], ), ): + # Additional validation when mapping endpoint not provided - ensure validation error for missing params + if not os.getenv("DATABRICKS_SERVING_ENDPOINT"): + required_params = {"data_type": data_type} + additionaly_validate_params(required_params) + self.data_type = data_type self.include_bad_data = include_bad_data self.start_date = start_date @@ -402,7 +435,8 @@ def __init__( self, data_type: str = Query( ..., - description="Data Type can be one of the following options:[float, double, integer, string]", + description="Data Type can be one of the following options: float, double, integer, string", + examples=["float", "double", "integer", "string"], ), timestamps: List[Union[date, datetime]] = Query( ..., @@ -416,6 +450,11 @@ def __init__( ..., description="Include or remove Bad data points" ), ): + # Additional validation when mapping endpoint not provided - ensure validation error for missing params + if not os.getenv("DATABRICKS_SERVING_ENDPOINT"): + required_params = {"data_type": data_type} + additionaly_validate_params(required_params) + self.data_type = data_type self.timestamps = timestamps self.window_length = window_length @@ -465,3 +504,29 @@ def __init__( self.time_interval_unit = time_interval_unit self.lower_bound = lower_bound self.upper_bound = upper_bound + + +class BatchDict(BaseModel): + url: str + method: str + params: dict + body: dict = None + + def __getitem__(self, item): + 
if item in self.__dict__: + return self.__dict__[item] + else: + raise KeyError(f"Key {item} not found in the model.") + + +class BatchBodyParams(BaseModel): + requests: List[BatchDict] + + +class BatchResponse(BaseModel): + schema: FieldSchema = Field(None, alias="schema", serialization_alias="schema") + data: List + + +class BatchListResponse(BaseModel): + data: List[BatchResponse] diff --git a/src/api/v1/plot.py b/src/api/v1/plot.py index 9a665532d..63378914b 100644 --- a/src/api/v1/plot.py +++ b/src/api/v1/plot.py @@ -31,7 +31,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def plot_events_get( @@ -52,7 +52,15 @@ def plot_events_get( base_headers=base_headers, ) - data = plot.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = plot.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("plot", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/raw.py b/src/api/v1/raw.py index a3d960f8c..2267a4151 100644 --- a/src/api/v1/raw.py +++ b/src/api/v1/raw.py @@ -27,7 +27,7 @@ HTTPError, ) from src.api.auth.azuread import oauth2_scheme -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get from src.api.FastAPIApp import api_v1_router @@ -47,7 +47,15 @@ def raw_events_get( base_headers=base_headers, ) - data = raw.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = raw.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("raw", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/resample.py b/src/api/v1/resample.py index 9b0059351..d3789a72a 100644 --- a/src/api/v1/resample.py +++ b/src/api/v1/resample.py @@ -32,7 +32,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def resample_events_get( @@ -55,7 +55,15 @@ def resample_events_get( base_headers=base_headers, ) - data = resample.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = resample.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("resample", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/summary.py b/src/api/v1/summary.py index a42c75041..ce8400e63 100644 --- a/src/api/v1/summary.py +++ b/src/api/v1/summary.py @@ 
-27,7 +27,7 @@ HTTPError, ) from src.api.auth.azuread import oauth2_scheme -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get from src.api.FastAPIApp import api_v1_router @@ -47,7 +47,15 @@ def summary_events_get( base_headers=base_headers, ) - data = summary.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = summary.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("summary", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/api/v1/time_weighted_average.py b/src/api/v1/time_weighted_average.py index ac1893d11..dac0759cc 100644 --- a/src/api/v1/time_weighted_average.py +++ b/src/api/v1/time_weighted_average.py @@ -30,7 +30,7 @@ PivotQueryParams, LimitOffsetQueryParams, ) -from src.api.v1.common import common_api_setup_tasks, json_response +from src.api.v1.common import common_api_setup_tasks, json_response, lookup_before_get def time_weighted_average_events_get( @@ -53,7 +53,15 @@ def time_weighted_average_events_get( base_headers=base_headers, ) - data = time_weighted_average.get(connection, parameters) + if all( + (key in parameters and parameters[key] != None) + for key in ["business_unit", "asset", "data_security_level", "data_type"] + ): + # if have all required params, run normally + data = time_weighted_average.get(connection, parameters) + else: + # else wrap in lookup function that finds tablenames and runs function (if mutliple tables, handles concurrent requests) + data = lookup_before_get("time_weighted_average", connection, parameters) return json_response(data, limit_offset_parameters) except Exception as e: diff --git a/src/sdk/python/rtdip_sdk/queries/time_series/batch.py b/src/sdk/python/rtdip_sdk/queries/time_series/batch.py new file mode 100755 index 000000000..56225ec4d --- /dev/null +++ b/src/sdk/python/rtdip_sdk/queries/time_series/batch.py @@ -0,0 +1,84 @@ +# Copyright 2022 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List +import logging +import pandas as pd +from ._time_series_query_builder import _query_builder +from ...connectors.odbc.db_sql_connector import DatabricksSQLConnection +from concurrent.futures import * + + +def get( + connection: object, request_list: List[dict], threadpool_max_workers=1 +) -> List[pd.DataFrame]: + """ + A function to return back raw data by querying databricks SQL Warehouse using a connection specified by the user. + + The available connectors by RTDIP are Databricks SQL Connect, PYODBC SQL Connect, TURBODBC SQL Connect. 
+ + The available authentication methods are Certificate Authentication, Client Secret Authentication or Default Authentication. See documentation. + + Args: + connection: Connection chosen by the user (Databricks SQL Connect, PYODBC SQL Connect, TURBODBC SQL Connect) + request_list: A list of dictionaries, each contaiing the type of request and a dictionary of parameters. + + Returns: + DataFrame: A list of dataframes of timeseries data. + + """ + try: + results = [] + + # Get connection parameters and close, as each thread will create new connection + server_hostname = connection.server_hostname + http_path = connection.http_path + access_token = connection.access_token + connection.close() + + def execute_request(connection_params, request): + # Create connection and cursor + connection = DatabricksSQLConnection(*connection_params) + cursor = connection.cursor() + + # Build query with query builder + query = _query_builder(request["parameters_dict"], request["type"]) + + # Execute query + try: + cursor.execute(query) + df = cursor.fetch_all() + return df + except Exception as e: + logging.exception("error returning dataframe") + raise e + finally: + # Close cursor and connection at end + cursor.close() + connection.close() + + with ThreadPoolExecutor(max_workers=threadpool_max_workers) as executor: + # Package up connection params into tuple + connection_params = (server_hostname, http_path, access_token) + + # Execute queries with threadpool - map preserves order + results = executor.map( + lambda arguments: execute_request(*arguments), + [(connection_params, request) for request in request_list], + ) + + return results + + except Exception as e: + logging.exception("error with batch function") + raise e diff --git a/tests/api/v1/api_test_objects.py b/tests/api/v1/api_test_objects.py index 7c414fa2d..098855a78 100644 --- a/tests/api/v1/api_test_objects.py +++ b/tests/api/v1/api_test_objects.py @@ -19,6 +19,7 @@ from tests.sdk.python.rtdip_sdk.queries.time_series._test_base import ( DATABRICKS_SQL_CONNECT, ) +import os START_DATE = "2011-01-01T00:00:00+00:00" END_DATE = "2011-01-02T00:00:00+00:00" @@ -230,12 +231,158 @@ } -def mocker_setup(mocker: MockerFixture, patch_method, test_data, side_effect=None): +# Batch api test parameters +BATCH_MOCKED_PARAMETER_DICT = { + "region": "mocked-region", +} + +BATCH_POST_PAYLOAD_SINGLE_WITH_GET = { + "requests": [ + { + "url": "/api/v1/events/summary", + "method": "GET", + "headers": TEST_HEADERS, + "params": SUMMARY_MOCKED_PARAMETER_DICT, + } + ] +} + +BATCH_POST_PAYLOAD_SINGLE_WITH_POST = { + "requests": [ + { + "url": "/api/v1/events/raw", + "method": "POST", + "headers": TEST_HEADERS, + "params": RAW_MOCKED_PARAMETER_DICT, + "body": RESAMPLE_POST_BODY_MOCKED_PARAMETER_DICT, + } + ] +} + +BATCH_POST_PAYLOAD_SINGLE_WITH_GET_ERROR_DICT = { + "requests": [ + { + "url": "an_unsupported_route", + "method": "GET", + "headers": TEST_HEADERS, + "params": SUMMARY_MOCKED_PARAMETER_DICT, + } + ] +} + +BATCH_POST_PAYLOAD_SINGLE_WITH_POST_ERROR_DICT = { + "requests": [ + { + "url": "/api/v1/events/raw", + "method": "POST", + "headers": TEST_HEADERS, + "params": RAW_MOCKED_PARAMETER_DICT, + # No body supplied + } + ] +} + +BATCH_POST_PAYLOAD_MULTIPLE = { + "requests": [ + { + "url": "/api/v1/events/summary", + "method": "GET", + "headers": TEST_HEADERS, + "params": SUMMARY_MOCKED_PARAMETER_DICT, + }, + { + "url": "/api/v1/events/raw", + "method": "POST", + "headers": TEST_HEADERS, + "params": RAW_MOCKED_PARAMETER_DICT, + "body": 
RESAMPLE_POST_BODY_MOCKED_PARAMETER_DICT, + }, + ] +} + +# Tag mapping test parameters + +MOCK_TAG_MAPPING_SINGLE = { + "outputs": [ + { + "TagName": "Tagname1", + "CatalogName": "rtdip", + "SchemaName": "sensors", + "DataTable": "asset1_restricted_events_float", + } + ] +} + +MOCK_TAG_MAPPING_MULTIPLE = { + "outputs": [ + { + "TagName": "Tagname1", + "CatalogName": "rtdip", + "SchemaName": "sensors", + "DataTable": "asset1_restricted_events_float", + }, + { + "TagName": "Tagname2", + "CatalogName": "rtdip", + "SchemaName": "sensors", + "DataTable": "asset1_restricted_events_float", + }, + { + "TagName": "Tagname3", + "CatalogName": "rtdip", + "SchemaName": "sensors", + "DataTable": "asset2_restricted_events_integer", + }, + ] +} + +MOCK_TAG_MAPPING_EMPTY = { + "outputs": [ + { + "TagName": "Tagname1", + "CatalogName": None, + "SchemaName": None, + "DataTable": None, + } + ] +} + +MOCK_TAG_MAPPING_BODY = {"dataframe_records": [{"TagName": "MOCKED-TAGNAME1"}]} + +MOCK_MAPPING_ENDPOINT_URL = "https://mockdatabricksmappingurl.com/serving-endpoints/metadata-mapping/invocations" + + +# Mocker set-up utility + + +def mocker_setup( + mocker: MockerFixture, + patch_method, + test_data, + side_effect=None, + patch_side_effect=None, + tag_mapping_data=None, +): mocker.patch( DATABRICKS_SQL_CONNECT, return_value=MockedDBConnection(), side_effect=side_effect, ) - mocker.patch(patch_method, return_value=test_data) + + if patch_side_effect is not None: + mocker.patch(patch_method, side_effect=patch_side_effect) + else: + mocker.patch(patch_method, return_value=test_data) + mocker.patch("src.api.auth.azuread.get_azure_ad_token", return_value="token") + + # Create a mock response object for tag mapping endpoint with a .json() method that returns the mock data + if tag_mapping_data is not None: + mock_response = mocker.MagicMock() + mock_response.json.return_value = tag_mapping_data + mock_response.status_code = 200 + + # Patch 'requests.post' to return the mock response + mocker.patch("requests.post", return_value=mock_response) + return mocker diff --git a/tests/api/v1/test_api_batch.py b/tests/api/v1/test_api_batch.py new file mode 100644 index 000000000..48f679117 --- /dev/null +++ b/tests/api/v1/test_api_batch.py @@ -0,0 +1,310 @@ +# Copyright 2022 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import json +import pytest +from pytest_mock import MockerFixture +import pandas as pd +import numpy as np +from datetime import datetime, timezone +from tests.api.v1.api_test_objects import ( + BATCH_MOCKED_PARAMETER_DICT, + BATCH_POST_PAYLOAD_SINGLE_WITH_GET, + BATCH_POST_PAYLOAD_SINGLE_WITH_POST, + BATCH_POST_PAYLOAD_SINGLE_WITH_GET_ERROR_DICT, + BATCH_POST_PAYLOAD_SINGLE_WITH_POST_ERROR_DICT, + BATCH_POST_PAYLOAD_MULTIPLE, + mocker_setup, + TEST_HEADERS, + BASE_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_MAPPING_ENDPOINT_URL, +) +from src.api.v1.models import ( + RawResponse, +) +from pandas.io.json import build_table_schema +from httpx import AsyncClient +from src.api.v1 import app +from src.api.v1.common import json_response_batch + +MOCK_METHOD = "src.sdk.python.rtdip_sdk.queries.time_series.raw.get" +MOCK_API_NAME = "/api/v1/events/batch" + +pytestmark = pytest.mark.anyio + + +async def test_api_batch_single_get_success(mocker: MockerFixture): + """ + Case when single get request supplied in array of correct format + """ + + test_data = pd.DataFrame( + { + "TagName": ["TestTag"], + "Count": [10.0], + "Avg": [5.05], + "Min": [1.0], + "Max": [10.0], + "StDev": [3.02], + "Sum": [25.0], + "Var": [0.0], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=BATCH_MOCKED_PARAMETER_DICT, + json=BATCH_POST_PAYLOAD_SINGLE_WITH_GET, + ) + + # Define full expected structure for one test - for remainder use json_response_batch as already tested in common + expected = { + "data": [ + { + "schema": { + "fields": [ + {"name": "TagName", "type": "string"}, + {"name": "Count", "type": "number"}, + {"name": "Avg", "type": "number"}, + {"name": "Min", "type": "number"}, + {"name": "Max", "type": "number"}, + {"name": "StDev", "type": "number"}, + {"name": "Sum", "type": "number"}, + {"name": "Var", "type": "number"}, + ], + "primaryKey": False, + "pandas_version": "1.4.0", + }, + "data": [ + { + "TagName": "TestTag", + "Count": 10.0, + "Avg": 5.05, + "Min": 1.0, + "Max": 10.0, + "StDev": 3.02, + "Sum": 25.0, + "Var": 0.0, + } + ], + } + ] + } + + assert actual.json() == expected + assert actual.status_code == 200 + + +async def test_api_batch_single_post_success(mocker: MockerFixture): + """ + Case when single post request supplied in array of correct format + """ + + test_data = pd.DataFrame( + { + "EventTime": [datetime.now(timezone.utc)], + "TagName": ["TestTag"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=BATCH_MOCKED_PARAMETER_DICT, 
+ json=BATCH_POST_PAYLOAD_SINGLE_WITH_POST, + ) + + expected = json.loads(json_response_batch([test_data]).body.decode("utf-8")) + + assert actual.json() == expected + assert actual.status_code == 200 + + +async def test_api_batch_single_get_unsupported_route_error(mocker: MockerFixture): + """ + Case when single post request supplied but route not supported + """ + + test_data = pd.DataFrame( + { + "EventTime": [datetime.now(timezone.utc)], + "TagName": ["TestTag"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=BATCH_MOCKED_PARAMETER_DICT, + json=BATCH_POST_PAYLOAD_SINGLE_WITH_GET_ERROR_DICT, + ) + + expected = { + "detail": "Unsupported url: Only relative base urls are supported. Please provide any parameters in the params key" + } + + assert actual.json() == expected + assert actual.status_code == 400 + + +async def test_api_batch_single_post_missing_body_error(mocker: MockerFixture): + """ + Case when single post request supplied in array of incorrect format (missing payload) + """ + + test_data = pd.DataFrame( + { + "EventTime": [datetime.now(timezone.utc)], + "TagName": ["TestTag"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=BATCH_MOCKED_PARAMETER_DICT, + json=BATCH_POST_PAYLOAD_SINGLE_WITH_POST_ERROR_DICT, + ) + + expected = { + "detail": "Incorrectly formatted request provided: All POST requests require a body" + } + + assert actual.json() == expected + assert actual.status_code == 400 + + +async def test_api_batch_multiple_success(mocker: MockerFixture): + """ + Case when single post request supplied in array of correct format + """ + + summary_test_data = pd.DataFrame( + { + "TagName": ["TestTag"], + "Count": [10.0], + "Avg": [5.05], + "Min": [1.0], + "Max": [10.0], + "StDev": [3.02], + "Sum": [25.0], + "Var": [0.0], + } + ) + + raw_test_data = pd.DataFrame( + { + "EventTime": [datetime.now(timezone.utc)], + "TagName": ["TestTag"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = None + # add side effect since require batch to return different data after each call + # batch.get return value is array of dfs, so must patch with nested array + mock_patch_side_effect = [[summary_test_data], [raw_test_data]] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + 
patch_side_effect=mock_patch_side_effect, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=BATCH_MOCKED_PARAMETER_DICT, + json=BATCH_POST_PAYLOAD_MULTIPLE, + ) + + expected = json.loads( + json_response_batch([summary_test_data, raw_test_data]).body.decode("utf-8") + ) + + assert actual.json() == expected + assert actual.status_code == 200 diff --git a/tests/api/v1/test_api_circular_average.py b/tests/api/v1/test_api_circular_average.py index 98a857676..99100c52f 100644 --- a/tests/api/v1/test_api_circular_average.py +++ b/tests/api/v1/test_api_circular_average.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json + +import os import pytest from pytest_mock import MockerFixture import pandas as pd @@ -25,6 +26,9 @@ mocker_setup, TEST_HEADERS, BASE_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_TAG_MAPPING_EMPTY, + MOCK_MAPPING_ENDPOINT_URL, ) from httpx import AsyncClient from src.api.v1 import app @@ -146,3 +150,140 @@ async def test_api_circular_average_post_error(mocker: MockerFixture, api_test_d assert response.status_code == 400 assert actual == '{"detail":"Error Connecting to Database"}' + + +async def test_api_circular_average_get_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "EventTime": [datetime.now(timezone.utc)], + "TagName": ["Tagname1"], + "Value": [1.5], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = CIRCULAR_AVERAGE_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_circular_average_post_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "EventTime": [CIRCULAR_AVERAGE_MOCKED_PARAMETER_DICT["start_date"]], + "TagName": ["Tagname1"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs 
lookup + modified_param_dict = CIRCULAR_AVERAGE_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=modified_param_dict, + json=CIRCULAR_AVERAGE_POST_BODY_MOCKED_PARAMETER_DICT, + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_circular_average_get_lookup_no_tag_map_error(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "EventTime": [CIRCULAR_AVERAGE_MOCKED_PARAMETER_DICT["start_date"]], + "TagName": ["Tagname1"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_EMPTY, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = CIRCULAR_AVERAGE_MOCKED_PARAMETER_DICT.copy() + modified_param_dict["tagname"] = ["NonExistentTag"] + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = '{"detail":"One or more tags do not have tables associated with them, the data belongs to a confidential table, or you do not have access. If the tag belongs to a confidential table and you do have access, please supply the business_unit, asset, data_security_level and data_type"}' + + assert actual.text == expected + assert actual.status_code == 400 diff --git a/tests/api/v1/test_api_common.py b/tests/api/v1/test_api_common.py new file mode 100644 index 000000000..b2d014046 --- /dev/null +++ b/tests/api/v1/test_api_common.py @@ -0,0 +1,372 @@ +# Copyright 2024 RTDIP +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from unittest.mock import patch +import json +import pandas as pd +import numpy as np +from datetime import datetime, timezone +from src.api.v1.common import ( + lookup_before_get, + query_mapping_endpoint, + split_table_name, + concatenate_dfs_and_order, + json_response_batch, +) +from src.sdk.python.rtdip_sdk.connectors import DatabricksSQLConnection +from src.sdk.python.rtdip_sdk.queries.time_series import raw + +from tests.api.v1.api_test_objects import ( + RAW_MOCKED_PARAMETER_DICT, + MOCK_MAPPING_ENDPOINT_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_TAG_MAPPING_MULTIPLE, + MOCK_TAG_MAPPING_EMPTY, + mocker_setup, +) + +############################### +# Mocker set-ups +############################### +MOCK_METHOD = "src.sdk.python.rtdip_sdk.queries.time_series.raw.get" +MOCK_BATCH_METHOD = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + + +############################### +# Tests for lookup_before_get +############################### +def test_api_lookup_before_get(mocker): + # parameters dict + test_parameters = RAW_MOCKED_PARAMETER_DICT + test_parameters["tag_names"] = ["Tagname1", "Tagname2", "Tagname3"] + + # Mock get, but for each time called provides next result + mock_get_data = [ + # Two tags in one table + pd.DataFrame( + { + "EventTime": [ + RAW_MOCKED_PARAMETER_DICT["start_date"], + RAW_MOCKED_PARAMETER_DICT["start_date"], + ], + "TagName": ["Tagname1", "Tagname2"], + "Status": ["Good", "Good"], + "Value": [1.01, 2.02], + } + ), + # One tag in another + pd.DataFrame( + { + "EventTime": [RAW_MOCKED_PARAMETER_DICT["end_date"]], + "TagName": ["Tagname3"], + "Status": ["Good"], + "Value": [3.03], + } + ), + ] + + # Set-up mocker + mocker = mocker_setup( + mocker, + MOCK_METHOD, + mock_get_data, + patch_side_effect=mock_get_data, + tag_mapping_data=MOCK_TAG_MAPPING_MULTIPLE, + ) + mocker.patch(MOCK_BATCH_METHOD, return_value=mock_get_data) + + # Get result from lookup_before_get function + connection = DatabricksSQLConnection( + access_token="token", server_hostname="test", http_path="test" + ) + actual = lookup_before_get("raw", connection, test_parameters) + + # Define expected result + expected = pd.DataFrame( + { + "EventTime": [ + RAW_MOCKED_PARAMETER_DICT["start_date"], + RAW_MOCKED_PARAMETER_DICT["start_date"], + RAW_MOCKED_PARAMETER_DICT["end_date"], + ], + "TagName": ["Tagname1", "Tagname2", "Tagname3"], + "Status": ["Good", "Good", "Good"], + "Value": [1.01, 2.02, 3.03], + } + ) + + # Assert equality + pd.testing.assert_frame_equal(actual, expected, check_dtype=True) + + +############################### +# Tests for query_mapping_endpoint +############################### + + +def test_api_common_query_mapping_endpoint(mocker): + # Set-up mocker + mocker_setup( + mocker, MOCK_METHOD, test_data={}, tag_mapping_data=MOCK_TAG_MAPPING_MULTIPLE + ) + + # Run the function + tags = ["Tagname1", "Tagname2"] + connection = DatabricksSQLConnection( + access_token="token", server_hostname="test", http_path="test" + ) + actual = query_mapping_endpoint( + tags, MOCK_MAPPING_ENDPOINT_URL, connection=connection + ) + + expected = { + "rtdip.sensors.asset1_restricted_events_float": ["Tagname1", "Tagname2"], + "rtdip.sensors.asset2_restricted_events_integer": ["Tagname3"], + } + + assert actual == expected + + +############################### +# Tests for splitTablename +############################### +def test_api_common_split_table_name(): + """Tests for splitting table name into dict of business_unit, asset etc""" + + actual_with_expected_format = 
split_table_name( + "test.sensors.asset_restricted_events_float" + ) + expected_with_expected_format = { + "business_unit": "test", + "asset": "asset", + "data_security_level": "restricted", + "data_type": "float", + } + + with pytest.raises(Exception) as actual_with_incorrect_format_missing: + split_table_name("test") + + with pytest.raises(Exception) as actual_with_incorrect_schema: + split_table_name("test.schema.asset_restricted_events_float") + + expected_with_incorrect_format_message = "Unsupported table name format supplied. Please use the format 'businessunit.schema.asset.datasecurityevel_events_datatype" + + assert actual_with_expected_format == expected_with_expected_format + assert ( + actual_with_incorrect_format_missing.value.args[0] + == expected_with_incorrect_format_message + ) + assert ( + actual_with_incorrect_schema.value.args[0] + == expected_with_incorrect_format_message + ) + + +############################### +# Tests for concatenate_dfs_and_order +############################### + +test_df1 = pd.DataFrame( + { + "EventTime": [ + "01/01/2024 14:00", + "01/01/2024 15:00", + ], + "TagName": ["TestTag1", "TestTag2"], + "Status": ["Good", "Good"], + "Value": [1.01, 2.02], + } +) + +test_df2 = pd.DataFrame( + { + "EventTime": ["01/01/2024 14:00"], + "TagName": ["TestTag3"], + "Status": ["Good"], + "Value": [3.03], + } +) + +test_df3_pivoted = pd.DataFrame( + { + "EventTime": [ + "01/01/2024 14:00", + "01/01/2024 15:00", + ], + "TestTag1": [1.01, 5.05], + "TestTag2": [2.02, 6.05], + } +) + +test_df4_pivoted = pd.DataFrame( + { + "EventTime": ["01/01/2024 14:00", "01/01/2024 15:00"], + "TestTag3": [4.04, 7.07], + } +) + + +def test_api_common_concatenate_dfs_and_order_unpivoted(): + """Tests unpivoted concatenation of dfs""" + + actual = concatenate_dfs_and_order( + dfs_arr=[test_df1, test_df2], + tags=["TestTag1", "TestTag2", "TestTag3"], + pivot=False, + ) + + expected = pd.DataFrame( + { + "EventTime": ["01/01/2024 14:00", "01/01/2024 15:00", "01/01/2024 14:00"], + "TagName": ["TestTag1", "TestTag2", "TestTag3"], + "Status": ["Good", "Good", "Good"], + "Value": [1.01, 2.02, 3.03], + } + ) + + pd.testing.assert_frame_equal(actual, expected, check_dtype=True) + + +def test_api_common_concatenate_dfs_and_order_pivoted(): + """Tests pivoted concatenation of dfs, which adds columns""" + + actual = concatenate_dfs_and_order( + dfs_arr=[test_df3_pivoted, test_df4_pivoted], + tags=["TestTag1", "TestTag2", "TestTag3"], + pivot=True, + ) + + expected = pd.DataFrame( + { + "EventTime": ["01/01/2024 14:00", "01/01/2024 15:00"], + "TestTag1": [1.01, 5.05], + "TestTag2": [2.02, 6.05], + "TestTag3": [4.04, 7.07], + } + ) + + pd.testing.assert_frame_equal(actual, expected, check_dtype=True) + + +def test_api_common_concatenate_dfs_and_order_pivoted_ordering(): + """Tests pivoted concatenation of dfs, with specific tag ordering""" + + actual = concatenate_dfs_and_order( + dfs_arr=[test_df3_pivoted, test_df4_pivoted], + tags=["TestTag2", "TestTag1", "TestTag3"], + pivot=True, + ) + + expected = pd.DataFrame( + { + "EventTime": ["01/01/2024 14:00", "01/01/2024 15:00"], + "TestTag2": [2.02, 6.05], + "TestTag1": [1.01, 5.05], + "TestTag3": [4.04, 7.07], + } + ) + + pd.testing.assert_frame_equal(actual, expected, check_dtype=True) + + +############################### +# Tests for json_response_batch +############################### +def test_api_common_json_response_batch(): + """Tests that should correctly combine list of dfs into a json response""" + + summary_test_data = pd.DataFrame( + { 
+ "TagName": ["TestTag"], + "Count": [10.0], + "Avg": [5.05], + "Min": [1.0], + "Max": [10.0], + "StDev": [3.02], + "Sum": [25.0], + "Var": [0.0], + } + ) + + raw_test_data = pd.DataFrame( + { + "EventTime": ["2024-06-27T15:35", "2024-06-27T15:45"], + "TagName": ["TestTag", "TestTag"], + "Status": ["Good", "Good"], + "Value": [1.01, 5.55], + } + ) + + actual = json_response_batch([summary_test_data, raw_test_data]) + + expected = { + "data": [ + { + "schema": { + "fields": [ + {"name": "TagName", "type": "string"}, + {"name": "Count", "type": "number"}, + {"name": "Avg", "type": "number"}, + {"name": "Min", "type": "number"}, + {"name": "Max", "type": "number"}, + {"name": "StDev", "type": "number"}, + {"name": "Sum", "type": "number"}, + {"name": "Var", "type": "number"}, + ], + "primaryKey": False, + "pandas_version": "1.4.0", + }, + "data": [ + { + "TagName": "TestTag", + "Count": 10.0, + "Avg": 5.05, + "Min": 1.0, + "Max": 10.0, + "StDev": 3.02, + "Sum": 25.0, + "Var": 0.0, + } + ], + }, + { + "schema": { + "fields": [ + {"name": "EventTime", "type": "string"}, + {"name": "TagName", "type": "string"}, + {"name": "Status", "type": "string"}, + {"name": "Value", "type": "number"}, + ], + "primaryKey": False, + "pandas_version": "1.4.0", + }, + "data": [ + { + "EventTime": "2024-06-27T15:35", + "TagName": "TestTag", + "Status": "Good", + "Value": 1.01, + }, + { + "EventTime": "2024-06-27T15:45", + "TagName": "TestTag", + "Status": "Good", + "Value": 5.55, + }, + ], + }, + ] + } + assert json.loads(actual.body) == expected diff --git a/tests/api/v1/test_api_latest.py b/tests/api/v1/test_api_latest.py index 9e9fe344d..1bc3b29a3 100644 --- a/tests/api/v1/test_api_latest.py +++ b/tests/api/v1/test_api_latest.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
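The lookup success tests in this file, and the matching tests in test_api_raw.py and test_api_metadata.py further down, all derive their expected payload with the same few string operations: serialise the DataFrame in pandas "table" orient, drop the UTC timezone marker (the API output appears to omit it), and re-append the closing brace together with the pagination object the API adds. A hedged helper (hypothetical name, shown only to make the repeated pattern explicit) would look like:

```python
import pandas as pd


def expected_lookup_response(df: pd.DataFrame) -> str:
    # Serialise the frame the same way the tests do; the API response carries no
    # tz marker on EventTime and ends with a pagination object, so rebuild both.
    body = df.to_json(orient="table", index=False, date_unit="ns")
    body = body.replace(',"tz":"UTC"', "").rstrip("}")
    return body + ',"pagination":{"limit":null,"offset":null,"next":null}}'
```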
+import os import pytest from pytest_mock import MockerFixture import pandas as pd @@ -24,6 +25,9 @@ mocker_setup, TEST_HEADERS, BASE_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_TAG_MAPPING_EMPTY, + MOCK_MAPPING_ENDPOINT_URL, ) from httpx import AsyncClient from src.api.v1 import app @@ -197,3 +201,153 @@ async def test_api_raw_post_error(mocker: MockerFixture, api_test_data): assert response.status_code == 400 assert actual == '{"detail":"Error Connecting to Database"}' + + +async def test_api_latest_get_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "TagName": ["TestTag"], + "EventTime": [datetime.now(timezone.utc)], + "Status": ["Good"], + "Value": ["1.01"], + "ValueType": ["string"], + "GoodEventTime": [datetime.now(timezone.utc)], + "GoodValue": ["1.01"], + "GoodValueType": ["string"], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_latest_post_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "TagName": ["TestTag"], + "EventTime": [datetime.now(timezone.utc)], + "Status": ["Good"], + "Value": ["1.01"], + "ValueType": ["string"], + "GoodEventTime": [datetime.now(timezone.utc)], + "GoodValue": ["1.01"], + "GoodValueType": ["string"], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=modified_param_dict, + json=METADATA_POST_BODY_MOCKED_PARAMETER_DICT, + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_latest_get_lookup_no_tag_map_error(mocker: MockerFixture): + """ + Case when no 
business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "TagName": ["TestTag"], + "EventTime": [datetime.now(timezone.utc)], + "Status": ["Good"], + "Value": ["1.01"], + "ValueType": ["string"], + "GoodEventTime": [datetime.now(timezone.utc)], + "GoodValue": ["1.01"], + "GoodValueType": ["string"], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_EMPTY, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + modified_param_dict["tagname"] = ["NonExistentTag"] + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = '{"detail":"One or more tags do not have tables associated with them, the data belongs to a confidential table, or you do not have access. If the tag belongs to a confidential table and you do have access, please supply the business_unit, asset, data_security_level and data_type"}' + + assert actual.text == expected + assert actual.status_code == 400 diff --git a/tests/api/v1/test_api_metadata.py b/tests/api/v1/test_api_metadata.py index c66fc3a49..966014ecb 100644 --- a/tests/api/v1/test_api_metadata.py +++ b/tests/api/v1/test_api_metadata.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
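# Note on the metadata lookup tests below: they follow the same recipe as the other
# *_lookup_* tests in this change. The SDK batch query
# (src.sdk.python.rtdip_sdk.queries.time_series.batch.get) is mocked because the
# lookup appears to fan the request out as one query per mapped table; the
# tag-to-table mapping is supplied through mocker_setup's tag_mapping_data;
# DATABRICKS_SERVING_ENDPOINT is pointed at MOCK_MAPPING_ENDPOINT_URL; and
# business_unit is removed from the parameter dict so the route falls back to the
# tag lookup. The MOCK_TAG_MAPPING_EMPTY variant then asserts the 400 response.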
+import os import pytest from pytest_mock import MockerFixture import pandas as pd @@ -23,12 +24,18 @@ mocker_setup, TEST_HEADERS, BASE_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_TAG_MAPPING_EMPTY, + MOCK_MAPPING_ENDPOINT_URL, ) from httpx import AsyncClient from src.api.v1 import app MOCK_METHOD = "src.sdk.python.rtdip_sdk.queries.metadata.get" MOCK_API_NAME = "/api/v1/metadata" +TEST_DATA = pd.DataFrame( + {"TagName": ["TestTag"], "UoM": ["UoM1"], "Description": ["Test Description"]} +) pytestmark = pytest.mark.anyio @@ -181,3 +188,114 @@ async def test_api_metadata_post_error(mocker: MockerFixture, api_test_data): assert response.status_code == 400 assert actual == '{"detail":"Error Connecting to Database"}' + + +async def test_api_metadata_get_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [TEST_DATA] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = TEST_DATA.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_metadata_post_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [TEST_DATA] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=modified_param_dict, + json=METADATA_POST_BODY_MOCKED_PARAMETER_DICT, + ) + + expected = TEST_DATA.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.replace(',"tz":"UTC"', "").rstrip("}") + + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_metadata_get_lookup_no_tag_map_error(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [TEST_DATA] + mocker = mocker_setup( + mocker, + mock_method, + 
mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_EMPTY, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = METADATA_MOCKED_PARAMETER_DICT.copy() + modified_param_dict["tagname"] = ["NonExistentTag"] + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = '{"detail":"One or more tags do not have tables associated with them, the data belongs to a confidential table, or you do not have access. If the tag belongs to a confidential table and you do have access, please supply the business_unit, asset, data_security_level and data_type"}' + + assert actual.text == expected + assert actual.status_code == 400 diff --git a/tests/api/v1/test_api_raw.py b/tests/api/v1/test_api_raw.py index 4430e9bc7..afde6d60b 100644 --- a/tests/api/v1/test_api_raw.py +++ b/tests/api/v1/test_api_raw.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import pytest from pytest_mock import MockerFixture from tests.api.v1.api_test_objects import ( @@ -22,8 +23,12 @@ mocker_setup, TEST_HEADERS, BASE_URL, + MOCK_TAG_MAPPING_SINGLE, + MOCK_TAG_MAPPING_EMPTY, + MOCK_MAPPING_ENDPOINT_URL, ) from pandas.io.json import build_table_schema +import pandas as pd from httpx import AsyncClient from src.api.v1 import app @@ -134,3 +139,140 @@ async def test_api_raw_post_error(mocker: MockerFixture, api_test_data): assert response.status_code == 400 assert actual == '{"detail":"Error Connecting to Database"}' + + +async def test_api_raw_get_lookup_success(mocker: MockerFixture, api_test_data): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "EventTime": [RAW_MOCKED_PARAMETER_DICT["start_date"]], + "TagName": ["Tagname1"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = RAW_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.rstrip("}") + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_raw_post_lookup_success(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + """ + + test_data = pd.DataFrame( + { + "EventTime": [RAW_MOCKED_PARAMETER_DICT["start_date"]], + "TagName": ["Tagname1"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = 
"src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_SINGLE, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup + modified_param_dict = RAW_POST_MOCKED_PARAMETER_DICT.copy() + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.post( + MOCK_API_NAME, + headers=TEST_HEADERS, + params=modified_param_dict, + json=RAW_POST_BODY_MOCKED_PARAMETER_DICT, + ) + + expected = test_data.to_json(orient="table", index=False, date_unit="ns") + expected = ( + expected.rstrip("}") + ',"pagination":{"limit":null,"offset":null,"next":null}}' + ) + + assert actual.text == expected + assert actual.status_code == 200 + + +async def test_api_raw_get_lookup_no_tag_map_error(mocker: MockerFixture): + """ + Case when no business_unit, asset etc supplied so instead invokes tag lookup + AND there is no table associated with the tag which results in error. + """ + + test_data = pd.DataFrame( + { + "EventTime": [RAW_MOCKED_PARAMETER_DICT["start_date"]], + "TagName": ["Tagname1"], + "Status": ["Good"], + "Value": [1.01], + } + ) + + # Mock the batch method, which outputs test data in the form of an array of dfs + mock_method = "src.sdk.python.rtdip_sdk.queries.time_series.batch.get" + mock_method_return_data = [test_data] + mocker = mocker_setup( + mocker, + mock_method, + mock_method_return_data, + tag_mapping_data=MOCK_TAG_MAPPING_EMPTY, + ) + mocker.patch.dict( + os.environ, {"DATABRICKS_SERVING_ENDPOINT": MOCK_MAPPING_ENDPOINT_URL} + ) + + # Remove parameters so that runs lookup, and add tag that does not exist + modified_param_dict = RAW_MOCKED_PARAMETER_DICT.copy() + modified_param_dict["tagname"] = ["NonExistentTag"] + del modified_param_dict["business_unit"] + + async with AsyncClient(app=app, base_url=BASE_URL) as ac: + actual = await ac.get( + MOCK_API_NAME, headers=TEST_HEADERS, params=modified_param_dict + ) + + expected = '{"detail":"One or more tags do not have tables associated with them, the data belongs to a confidential table, or you do not have access. If the tag belongs to a confidential table and you do have access, please supply the business_unit, asset, data_security_level and data_type"}' + + assert actual.text == expected + assert actual.status_code == 400 diff --git a/tests/sdk/python/rtdip_sdk/pipelines/monitoring/spark/data_quality/test_great_expectations_data_quality.py b/tests/sdk/python/rtdip_sdk/pipelines/monitoring/spark/data_quality/test_great_expectations_data_quality.py index 69218c439..00bb57902 100644 --- a/tests/sdk/python/rtdip_sdk/pipelines/monitoring/spark/data_quality/test_great_expectations_data_quality.py +++ b/tests/sdk/python/rtdip_sdk/pipelines/monitoring/spark/data_quality/test_great_expectations_data_quality.py @@ -42,7 +42,6 @@ def test_create_expectations(mocker: MockerFixture): def test_build_expectations(): - expectation_type = "expect_column_values_to_not_be_null" exception_dict = { "column": "user_id",