From def7fe0cf0356a43998296280b66b7c685a0e43d Mon Sep 17 00:00:00 2001
From: GBBBAS <42962356+GBBBAS@users.noreply.github.com>
Date: Tue, 23 Jul 2024 13:43:15 +0100
Subject: [PATCH] API Memory Improvements (#782)

* API Memory Management Updates

Signed-off-by: GBBBAS

* Remove Unit Test

Signed-off-by: GBBBAS

* Test Updates

Signed-off-by: GBBBAS

* Updates for String data return type

Signed-off-by: GBBBAS

* Update for docs

Signed-off-by: GBBBAS

* Update to docs type

Signed-off-by: GBBBAS

* Update to type

Signed-off-by: GBBBAS

* Update for query tests

Signed-off-by: GBBBAS

---------

Signed-off-by: GBBBAS
---
 src/api/v1/common.py                          | 25 ++++----
 src/api/v1/time_weighted_average.py           |  1 -
 .../python/rtdip_sdk/connectors/__init__.py   |  1 +
 src/sdk/python/rtdip_sdk/connectors/models.py | 22 +++++++
 .../connectors/odbc/db_sql_connector.py       | 57 ++++++++++++++++---
 .../connectors/odbc/turbodbc_sql_connector.py | 23 ++++++--
 .../pipelines/_pipeline_utils/constants.py    |  2 +-
 tests/api/v1/test_api_raw.py                  |  1 +
 tests/conftest.py                             | 36 ++++++------
 .../spark/test_ssip_pi_binary_file_to_pcdm.py |  2 +-
 10 files changed, 124 insertions(+), 46 deletions(-)
 create mode 100644 src/sdk/python/rtdip_sdk/connectors/models.py

diff --git a/src/api/v1/common.py b/src/api/v1/common.py
index b50e75390..714c66e73 100644
--- a/src/api/v1/common.py
+++ b/src/api/v1/common.py
@@ -21,8 +21,12 @@
 from fastapi import Response
 import dateutil.parser
 from pandas import DataFrame
+import pyarrow as pa
 from pandas.io.json import build_table_schema
-from src.sdk.python.rtdip_sdk.connectors import DatabricksSQLConnection
+from src.sdk.python.rtdip_sdk.connectors import (
+    DatabricksSQLConnection,
+    ConnectionReturnType,
+)

 if importlib.util.find_spec("turbodbc") != None:
     from src.sdk.python.rtdip_sdk.connectors import TURBODBCSQLConnection
@@ -69,12 +73,14 @@ def common_api_setup_tasks(  # NOSONAR
             databricks_server_host_name,
             databricks_http_path,
             token,
+            ConnectionReturnType.String,
         )
     else:
         connection = DatabricksSQLConnection(
             databricks_server_host_name,
             databricks_http_path,
             token,
+            ConnectionReturnType.String,
         )

     parameters = base_query_parameters.__dict__
@@ -136,7 +142,7 @@ def common_api_setup_tasks(  # NOSONAR
     return connection, parameters


-def pagination(limit_offset_parameters: LimitOffsetQueryParams, data: DataFrame):
+def pagination(limit_offset_parameters: LimitOffsetQueryParams, rows: int):
     pagination = PaginationRow(
         limit=None,
         offset=None,
@@ -150,7 +156,7 @@ def pagination(limit_offset_parameters: LimitOffsetQueryParams, data: DataFrame)
     next_offset = None

     if (
-        len(data.index) == limit_offset_parameters.limit
+        rows == limit_offset_parameters.limit
         and limit_offset_parameters.offset is not None
     ):
         next_offset = limit_offset_parameters.offset + limit_offset_parameters.limit
@@ -178,11 +184,11 @@ def datetime_parser(json_dict):


 def json_response(
-    data: DataFrame, limit_offset_parameters: LimitOffsetQueryParams
+    data: dict, limit_offset_parameters: LimitOffsetQueryParams
 ) -> Response:
     schema_df = pd.DataFrame()
-    if not data.empty:
-        json_str = data.loc[0, "Value"]
+    if data["data"] is not None and data["data"] != "":
+        json_str = data["data"][0 : data["data"].find("}") + 1]
         json_dict = json.loads(json_str, object_hook=datetime_parser)

         schema_df = pd.json_normalize(json_dict)
@@ -192,11 +198,8 @@ def json_response(
             FieldSchema.model_validate(
                 build_table_schema(schema_df, index=False, primary_key=False),
             ).model_dump_json(),
-            "[" + ",".join(data["Value"]) + "]",
-            # data.replace({np.nan: None}).to_json(
-            #     orient="records", date_format="iso", date_unit="ns"
-            # ),
-            pagination(limit_offset_parameters, data).model_dump_json(),
+            "[" + data["data"] + "]",
+            pagination(limit_offset_parameters, data["count"]).model_dump_json(),
         )
         + "}",
         media_type="application/json",
diff --git a/src/api/v1/time_weighted_average.py b/src/api/v1/time_weighted_average.py
index a9fbdb611..ac1893d11 100644
--- a/src/api/v1/time_weighted_average.py
+++ b/src/api/v1/time_weighted_average.py
@@ -54,7 +54,6 @@ def time_weighted_average_events_get(
         )

         data = time_weighted_average.get(connection, parameters)
-        data = data.reset_index()
         return json_response(data, limit_offset_parameters)

     except Exception as e:
diff --git a/src/sdk/python/rtdip_sdk/connectors/__init__.py b/src/sdk/python/rtdip_sdk/connectors/__init__.py
index e52897d02..824c69ba2 100644
--- a/src/sdk/python/rtdip_sdk/connectors/__init__.py
+++ b/src/sdk/python/rtdip_sdk/connectors/__init__.py
@@ -22,3 +22,4 @@
 if importlib.util.find_spec("pyspark") != None:
     from .grpc.spark_connector import *
 from .llm.chatopenai_databricks_connector import *
+from .models import *
diff --git a/src/sdk/python/rtdip_sdk/connectors/models.py b/src/sdk/python/rtdip_sdk/connectors/models.py
new file mode 100644
index 000000000..31c8e2f2c
--- /dev/null
+++ b/src/sdk/python/rtdip_sdk/connectors/models.py
@@ -0,0 +1,22 @@
+# Copyright 2024 RTDIP
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class ConnectionReturnType(str, Enum):
+    Pandas = "pandas"
+    Pyarrow = "pyarrow"
+    List = "list"
+    String = "string"
diff --git a/src/sdk/python/rtdip_sdk/connectors/odbc/db_sql_connector.py b/src/sdk/python/rtdip_sdk/connectors/odbc/db_sql_connector.py
index c02031ff5..0a4182f51 100644
--- a/src/sdk/python/rtdip_sdk/connectors/odbc/db_sql_connector.py
+++ b/src/sdk/python/rtdip_sdk/connectors/odbc/db_sql_connector.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Union
 from databricks import sql
 import pyarrow as pa
+import polars as pl
 from ..connection_interface import ConnectionInterface
 from ..cursor_interface import CursorInterface
+from ..models import ConnectionReturnType
 import logging
@@ -32,10 +35,17 @@ class DatabricksSQLConnection(ConnectionInterface):
         access_token: Azure AD or Databricks PAT token
     """

-    def __init__(self, server_hostname: str, http_path: str, access_token: str) -> None:
+    def __init__(
+        self,
+        server_hostname: str,
+        http_path: str,
+        access_token: str,
+        return_type=ConnectionReturnType.Pandas,
+    ) -> None:
         self.server_hostname = server_hostname
         self.http_path = http_path
         self.access_token = access_token
+        self.return_type = return_type
         # call auth method
         self.connection = self._connect()
@@ -70,7 +80,7 @@ def cursor(self) -> object:
         try:
             if self.connection.open == False:
                 self.connection = self._connect()
-            return DatabricksSQLCursor(self.connection.cursor())
+            return DatabricksSQLCursor(self.connection.cursor(), self.return_type)
         except Exception as e:
             logging.exception("error with cursor object")
             raise e
@@ -84,8 +94,9 @@ class DatabricksSQLCursor(CursorInterface):
         cursor: controls execution of commands on cluster or SQL Warehouse
     """

-    def __init__(self, cursor: object) -> None:
+    def __init__(self, cursor: object, return_type=ConnectionReturnType.Pandas) -> None:
         self.cursor = cursor
+        self.return_type = return_type

     def execute(self, query: str) -> None:
         """
@@ -100,7 +111,7 @@ def execute(self, query: str) -> None:
             logging.exception("error while executing the query")
             raise e

-    def fetch_all(self, fetch_size=5_000_000) -> list:
+    def fetch_all(self, fetch_size=5_000_000) -> Union[list, dict]:
         """
         Gets all rows of a query.

@@ -109,16 +120,44 @@ def fetch_all(self, fetch_size=5_000_000) -> list:
         """
         try:
             get_next_result = True
-            results = []
+            results = None if self.return_type == ConnectionReturnType.String else []
+            count = 0
             while get_next_result:
                 result = self.cursor.fetchmany_arrow(fetch_size)
-                results.append(result)
+                count += result.num_rows
+                if self.return_type == ConnectionReturnType.List:
+                    column_list = []
+                    for column in result.columns:
+                        column_list.append(column.to_pylist())
+                    results.extend(zip(*column_list))
+                elif self.return_type == ConnectionReturnType.String:
+                    column_list = []
+                    for column in result.columns:
+                        column_list.append(column.to_pylist())
+
+                    strings = ",".join([str(item[0]) for item in zip(*column_list)])
+                    if results is None:
+                        results = strings
+                    else:
+                        results = ",".join([results, strings])
+                else:
+                    results.append(result)
                 if result.num_rows < fetch_size:
                     get_next_result = False

-            pyarrow_table = pa.concat_tables(results)
-            df = pyarrow_table.to_pandas()
-            return df
+            if self.return_type == ConnectionReturnType.Pandas:
+                pyarrow_table = pa.concat_tables(results)
+                return pyarrow_table.to_pandas()
+            elif self.return_type == ConnectionReturnType.Pyarrow:
+                pyarrow_table = pa.concat_tables(results)
+                return pyarrow_table
+            elif self.return_type == ConnectionReturnType.List:
+                return results
+            elif self.return_type == ConnectionReturnType.String:
+                return {
+                    "data": results,
+                    "count": count,
+                }
         except Exception as e:
             logging.exception("error while fetching the rows of a query")
             raise e
diff --git a/src/sdk/python/rtdip_sdk/connectors/odbc/turbodbc_sql_connector.py b/src/sdk/python/rtdip_sdk/connectors/odbc/turbodbc_sql_connector.py
index b1da5b285..30608e420 100644
--- a/src/sdk/python/rtdip_sdk/connectors/odbc/turbodbc_sql_connector.py
+++ b/src/sdk/python/rtdip_sdk/connectors/odbc/turbodbc_sql_connector.py
@@ -17,6 +17,7 @@
 from ..._sdk_utils.compare_versions import _package_version_meets_minimum
 from ..connection_interface import ConnectionInterface
 from ..cursor_interface import CursorInterface
+from ..models import ConnectionReturnType
 import logging
 import os
@@ -37,11 +38,18 @@ class TURBODBCSQLConnection(ConnectionInterface):
         More fields such as driver can be configured upon extension.
     """

-    def __init__(self, server_hostname: str, http_path: str, access_token: str) -> None:
+    def __init__(
+        self,
+        server_hostname: str,
+        http_path: str,
+        access_token: str,
+        return_type=ConnectionReturnType.Pandas,
+    ) -> None:
         _package_version_meets_minimum("turbodbc", "4.0.0")
         self.server_hostname = server_hostname
         self.http_path = http_path
         self.access_token = access_token
+        self.return_type = return_type
         # call auth method
         self.connection = self._connect()
         self.open = True
@@ -97,7 +105,9 @@ def cursor(self) -> object:
         try:
             if self.open == False:
                 self.connection = self._connect()
-            return TURBODBCSQLCursor(self.connection.cursor())
+            return TURBODBCSQLCursor(
+                self.connection.cursor(), return_type=self.return_type
+            )
         except Exception as e:
             logging.exception("error with cursor object")
             raise e
@@ -111,8 +121,9 @@ class TURBODBCSQLCursor(CursorInterface):
         cursor: controls execution of commands on cluster or SQL Warehouse
     """

-    def __init__(self, cursor: object) -> None:
+    def __init__(self, cursor: object, return_type=ConnectionReturnType.Pandas) -> None:
         self.cursor = cursor
+        self.return_type = return_type

     def execute(self, query: str) -> None:
         """
@@ -136,8 +147,10 @@ def fetch_all(self) -> list:
         """
         try:
             result = self.cursor.fetchallarrow()
-            df = result.to_pandas()
-            return df
+            if self.return_type == ConnectionReturnType.Pyarrow:
+                return result
+            elif self.return_type == ConnectionReturnType.Pandas:
+                return result.to_pandas()
         except Exception as e:
             logging.exception("error while fetching the rows from the query")
             raise e
diff --git a/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/constants.py b/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/constants.py
index 5af5f2f46..7c0d32112 100644
--- a/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/constants.py
+++ b/src/sdk/python/rtdip_sdk/pipelines/_pipeline_utils/constants.py
@@ -65,7 +65,7 @@ def get_default_package(package_name):
         "aws_boto3": PyPiLibrary(name="boto3", version="1.28.2"),
         "hashicorp_vault": PyPiLibrary(name="hvac", version="1.1.0"),
         "api_requests": PyPiLibrary(name="requests", version="2.30.0"),
-        "pyarrow": PyPiLibrary(name="pyarrow", version="12.0.0"),
+        "pyarrow": PyPiLibrary(name="pyarrow", version="14.0.2"),
         "pandas": PyPiLibrary(name="pandas", version="2.0.1"),
     }
     return DEFAULT_PACKAGES[package_name]
diff --git a/tests/api/v1/test_api_raw.py b/tests/api/v1/test_api_raw.py
index 07385fb88..bd4e64a5c 100644
--- a/tests/api/v1/test_api_raw.py
+++ b/tests/api/v1/test_api_raw.py
@@ -17,6 +17,7 @@
 import pandas as pd
 import numpy as np
 from datetime import datetime, timezone
+from src.sdk.python.rtdip_sdk.authentication.azure import DefaultAuth
 from tests.api.v1.api_test_objects import (
     RAW_MOCKED_PARAMETER_DICT,
     RAW_MOCKED_PARAMETER_ERROR_DICT,
diff --git a/tests/conftest.py b/tests/conftest.py
index d09a0ea5e..7b0f7c624 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -99,9 +99,7 @@ def api_test_data():
     }
     mock_raw_data = test_raw_data.copy()
     mock_raw_data["EventTime"] = mock_raw_data["EventTime"].strftime(datetime_format)
-    mock_raw_df = pd.DataFrame(
-        {"Value": [json.dumps(mock_raw_data, separators=(",", ":"))]}
-    )
+    mock_raw_df = {"data": json.dumps(mock_raw_data, separators=(",", ":")), "count": 1}
     expected_raw = expected_result(test_raw_data)

     # Mock Aggregated Data
@@ -112,9 +110,7 @@ def api_test_data():
     }
     mock_agg_data = test_agg_data.copy()
     mock_agg_data["EventTime"] = mock_agg_data["EventTime"].strftime(datetime_format)
-    mock_agg_df = pd.DataFrame(
-        {"Value": [json.dumps(mock_agg_data, separators=(",", ":"))]}
-    )
+    mock_agg_df = {"data": json.dumps(mock_agg_data, separators=(",", ":")), "count": 1}
     expected_agg = expected_result(test_agg_data)

     # Summary Data
@@ -131,9 +127,10 @@ def api_test_data():

     mock_plot_data = test_plot_data.copy()
     mock_plot_data["EventTime"] = mock_plot_data["EventTime"].strftime(datetime_format)
-    mock_plot_df = pd.DataFrame(
-        {"Value": [json.dumps(mock_plot_data, separators=(",", ":"))]}
-    )
+    mock_plot_df = {
+        "data": json.dumps(mock_plot_data, separators=(",", ":")),
+        "count": 1,
+    }
     expected_plot = expected_result(test_plot_data)

     test_summary_data = {
@@ -147,9 +144,10 @@ def api_test_data():
         "Var": 0.0,
     }

-    mock_summary_df = pd.DataFrame(
-        {"Value": [json.dumps(test_summary_data, separators=(",", ":"))]}
-    )
+    mock_summary_df = {
+        "data": json.dumps(test_summary_data, separators=(",", ":")),
+        "count": 1,
+    }
     expected_summary = expected_result(test_summary_data)

     test_metadata = {
@@ -158,9 +156,10 @@ def api_test_data():
         "Description": "Test Description",
     }

-    mock_metadata_df = pd.DataFrame(
-        {"Value": [json.dumps(test_metadata, separators=(",", ":"))]}
-    )
+    mock_metadata_df = {
+        "data": json.dumps(test_metadata, separators=(",", ":")),
+        "count": 1,
+    }
     expected_metadata = expected_result(test_metadata)

     test_latest_data = {
@@ -181,9 +180,10 @@ def api_test_data():
     mock_latest_data["GoodEventTime"] = mock_latest_data["GoodEventTime"].strftime(
         datetime_format
     )
-    mock_latest_df = pd.DataFrame(
-        {"Value": [json.dumps(mock_latest_data, separators=(",", ":"))]}
-    )
+    mock_latest_df = {
+        "data": json.dumps(mock_latest_data, separators=(",", ":")),
+        "count": 1,
+    }
     expected_latest = expected_result(test_latest_data)

     expected_sql = expected_result(test_raw_data, "100", "100")
diff --git a/tests/sdk/python/rtdip_sdk/pipelines/transformers/spark/test_ssip_pi_binary_file_to_pcdm.py b/tests/sdk/python/rtdip_sdk/pipelines/transformers/spark/test_ssip_pi_binary_file_to_pcdm.py
index 18d2cc249..a9a49e7c1 100644
--- a/tests/sdk/python/rtdip_sdk/pipelines/transformers/spark/test_ssip_pi_binary_file_to_pcdm.py
+++ b/tests/sdk/python/rtdip_sdk/pipelines/transformers/spark/test_ssip_pi_binary_file_to_pcdm.py
@@ -45,7 +45,7 @@ def test_ssip_binary_file_to_pcdm_setup():
     assert ssip_pi_binary_file_to_pcdm.libraries() == Libraries(
         maven_libraries=[],
         pypi_libraries=[
-            PyPiLibrary(name="pyarrow", version="12.0.0", repo=None),
+            PyPiLibrary(name="pyarrow", version="14.0.2", repo=None),
             PyPiLibrary(name="pandas", version="2.0.1", repo=None),
         ],
         pythonwheel_libraries=[],
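
Usage note (not part of the patch): the sketch below shows how a caller can exercise the new ConnectionReturnType.String path introduced above, mirroring what the patched common.py does. The server hostname, HTTP path, token and SQL query are placeholder values, not anything defined in this change.

from src.sdk.python.rtdip_sdk.connectors import (
    DatabricksSQLConnection,
    ConnectionReturnType,
)

# Request the string return type so fetch_all builds one concatenated JSON string
# instead of materialising a pandas DataFrame (the memory saving this PR targets).
connection = DatabricksSQLConnection(
    "adb-0000000000000000.0.azuredatabricks.net",  # placeholder server hostname
    "/sql/1.0/warehouses/0000000000000000",  # placeholder http path
    "dapi-placeholder-token",  # placeholder access token
    ConnectionReturnType.String,
)

cursor = connection.cursor()
cursor.execute("SELECT `Value` FROM example_events")  # placeholder query

# With ConnectionReturnType.String, fetch_all returns a dict rather than a DataFrame:
# "data" is the comma-joined first column of every fetched row and "count" is the
# total row count, which pagination() now uses instead of len(data.index).
result = cursor.fetch_all()
json_payload = "[" + result["data"] + "]"  # same wrapping json_response applies
total_rows = result["count"]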