API Memory Improvements (rtdip#782)
* API Memory Management Updates

Signed-off-by: GBBBAS <[email protected]>

* Remove Unit Test

Signed-off-by: GBBBAS <[email protected]>

* Test Updates

Signed-off-by: GBBBAS <[email protected]>

* Updates for String data return type

Signed-off-by: GBBBAS <[email protected]>

* Update for docs

Signed-off-by: GBBBAS <[email protected]>

* Update to docs type

Signed-off-by: GBBBAS <[email protected]>

* Update to type

Signed-off-by: GBBBAS <[email protected]>

* Update for query tests

Signed-off-by: GBBBAS <[email protected]>

---------

Signed-off-by: GBBBAS <[email protected]>
GBBBAS authored Jul 23, 2024
1 parent 4c070a4 commit def7fe0
Showing 10 changed files with 124 additions and 46 deletions.
25 changes: 14 additions & 11 deletions src/api/v1/common.py
@@ -21,8 +21,12 @@
from fastapi import Response
import dateutil.parser
from pandas import DataFrame
import pyarrow as pa
from pandas.io.json import build_table_schema
from src.sdk.python.rtdip_sdk.connectors import DatabricksSQLConnection
from src.sdk.python.rtdip_sdk.connectors import (
DatabricksSQLConnection,
ConnectionReturnType,
)

if importlib.util.find_spec("turbodbc") != None:
from src.sdk.python.rtdip_sdk.connectors import TURBODBCSQLConnection
@@ -69,12 +73,14 @@ def common_api_setup_tasks( # NOSONAR
databricks_server_host_name,
databricks_http_path,
token,
ConnectionReturnType.String,
)
else:
connection = DatabricksSQLConnection(
databricks_server_host_name,
databricks_http_path,
token,
ConnectionReturnType.String,
)

parameters = base_query_parameters.__dict__
@@ -136,7 +142,7 @@ def common_api_setup_tasks( # NOSONAR
return connection, parameters


def pagination(limit_offset_parameters: LimitOffsetQueryParams, data: DataFrame):
def pagination(limit_offset_parameters: LimitOffsetQueryParams, rows: int):
pagination = PaginationRow(
limit=None,
offset=None,
@@ -150,7 +156,7 @@ def pagination(limit_offset_parameters: LimitOffsetQueryParams, data: DataFrame)
next_offset = None

if (
len(data.index) == limit_offset_parameters.limit
rows == limit_offset_parameters.limit
and limit_offset_parameters.offset is not None
):
next_offset = limit_offset_parameters.offset + limit_offset_parameters.limit
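
pagination now takes the connector's row count directly instead of calling len(data.index) on a DataFrame, so no frame has to be materialized just to size the page. The next_offset rule is unchanged; an illustrative evaluation with assumed values:

    # Assumed values: a full page (rows == limit) signals more data may follow.
    limit, offset, rows = 100, 200, 100
    next_offset = offset + limit if (rows == limit and offset is not None) else None
    # next_offset == 300; a short page (rows < limit) would leave it None
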
@@ -178,11 +184,11 @@ def datetime_parser(json_dict):


def json_response(
data: DataFrame, limit_offset_parameters: LimitOffsetQueryParams
data: dict, limit_offset_parameters: LimitOffsetQueryParams
) -> Response:
schema_df = pd.DataFrame()
if not data.empty:
json_str = data.loc[0, "Value"]
if data["data"] is not None and data["data"] != "":
json_str = data["data"][0 : data["data"].find("}") + 1]
json_dict = json.loads(json_str, object_hook=datetime_parser)
schema_df = pd.json_normalize(json_dict)

@@ -192,11 +198,8 @@ def json_response(
FieldSchema.model_validate(
build_table_schema(schema_df, index=False, primary_key=False),
).model_dump_json(),
"[" + ",".join(data["Value"]) + "]",
# data.replace({np.nan: None}).to_json(
# orient="records", date_format="iso", date_unit="ns"
# ),
pagination(limit_offset_parameters, data).model_dump_json(),
"[" + data["data"] + "]",
pagination(limit_offset_parameters, data["count"]).model_dump_json(),
)
+ "}",
media_type="application/json",
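
With the connector now handing json_response a {"data": str, "count": int} payload, the field schema is inferred by slicing the first serialized record out of the string rather than reading row 0 of a DataFrame. A sketch of that slice with illustrative data; note it assumes flat records, since a nested object would close a brace early:

    data = '{"TagName":"A","Value":1.0},{"TagName":"B","Value":2.0}'
    first_record = data[0 : data.find("}") + 1]
    # first_record == '{"TagName":"A","Value":1.0}'
    # pd.json_normalize(json.loads(first_record)) then supplies the schema,
    # and the full body is emitted as "[" + data + "]".
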
1 change: 0 additions & 1 deletion src/api/v1/time_weighted_average.py
@@ -54,7 +54,6 @@ def time_weighted_average_events_get(
)

data = time_weighted_average.get(connection, parameters)
data = data.reset_index()

return json_response(data, limit_offset_parameters)
except Exception as e:
1 change: 1 addition & 0 deletions src/sdk/python/rtdip_sdk/connectors/__init__.py
@@ -22,3 +22,4 @@
if importlib.util.find_spec("pyspark") != None:
from .grpc.spark_connector import *
from .llm.chatopenai_databricks_connector import *
from .models import *
22 changes: 22 additions & 0 deletions src/sdk/python/rtdip_sdk/connectors/models.py
@@ -0,0 +1,22 @@
# Copyright 2024 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class ConnectionReturnType(str, Enum):
Pandas = "pandas"
Pyarrow = "pyarrow"
List = "list"
String = "string"
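
The enum is consumed by the connection constructors below, with Pandas remaining the default. A minimal sketch of opting into the new string return type (hostname, path, and token are placeholder values):

    from src.sdk.python.rtdip_sdk.connectors import (
        ConnectionReturnType,
        DatabricksSQLConnection,
    )

    # Hypothetical workspace details for illustration only.
    connection = DatabricksSQLConnection(
        "adb-0000000000000000.0.azuredatabricks.net",
        "/sql/1.0/warehouses/0000000000000000",
        "dapi-example-token",
        ConnectionReturnType.String,
    )
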
57 changes: 48 additions & 9 deletions src/sdk/python/rtdip_sdk/connectors/odbc/db_sql_connector.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from databricks import sql
import pyarrow as pa
import polars as pl
from ..connection_interface import ConnectionInterface
from ..cursor_interface import CursorInterface
from ..models import ConnectionReturnType
import logging


@@ -32,10 +35,17 @@ class DatabricksSQLConnection(ConnectionInterface):
access_token: Azure AD or Databricks PAT token
"""

def __init__(self, server_hostname: str, http_path: str, access_token: str) -> None:
def __init__(
self,
server_hostname: str,
http_path: str,
access_token: str,
return_type=ConnectionReturnType.Pandas,
) -> None:
self.server_hostname = server_hostname
self.http_path = http_path
self.access_token = access_token
self.return_type = return_type
# call auth method
self.connection = self._connect()

@@ -70,7 +80,7 @@ def cursor(self) -> object:
try:
if self.connection.open == False:
self.connection = self._connect()
return DatabricksSQLCursor(self.connection.cursor())
return DatabricksSQLCursor(self.connection.cursor(), self.return_type)
except Exception as e:
logging.exception("error with cursor object")
raise e
@@ -84,8 +94,9 @@ class DatabricksSQLCursor(CursorInterface):
cursor: controls execution of commands on cluster or SQL Warehouse
"""

def __init__(self, cursor: object) -> None:
def __init__(self, cursor: object, return_type=ConnectionReturnType.Pandas) -> None:
self.cursor = cursor
self.return_type = return_type

def execute(self, query: str) -> None:
"""
@@ -100,7 +111,7 @@ def execute(self, query: str) -> None:
logging.exception("error while executing the query")
raise e

def fetch_all(self, fetch_size=5_000_000) -> list:
def fetch_all(self, fetch_size=5_000_000) -> Union[list, dict]:
"""
Gets all rows of a query.
@@ -109,16 +120,44 @@ def fetch_all(self, fetch_size=5_000_000) -> list:
"""
try:
get_next_result = True
results = []
results = None if self.return_type == ConnectionReturnType.String else []
count = 0
while get_next_result:
result = self.cursor.fetchmany_arrow(fetch_size)
results.append(result)
count += result.num_rows
if self.return_type == ConnectionReturnType.List:
column_list = []
for column in result.columns:
column_list.append(column.to_pylist())
results.extend(zip(*column_list))
elif self.return_type == ConnectionReturnType.String:
column_list = []
for column in result.columns:
column_list.append(column.to_pylist())

strings = ",".join([str(item[0]) for item in zip(*column_list)])
if results is None:
results = strings
else:
results = ",".join([results, strings])
else:
results.append(result)
if result.num_rows < fetch_size:
get_next_result = False

pyarrow_table = pa.concat_tables(results)
df = pyarrow_table.to_pandas()
return df
if self.return_type == ConnectionReturnType.Pandas:
pyarrow_table = pa.concat_tables(results)
return pyarrow_table.to_pandas()
elif self.return_type == ConnectionReturnType.Pyarrow:
pyarrow_table = pa.concat_tables(results)
return pyarrow_table
elif self.return_type == ConnectionReturnType.List:
return results
elif self.return_type == ConnectionReturnType.String:
return {
"data": results,
"count": count,
}
except Exception as e:
logging.exception("error while fetching the rows of a query")
raise e
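
For ConnectionReturnType.String, fetch_all folds each Arrow batch straight into one comma-separated string and keeps only a running row count, instead of accumulating Arrow tables for a final to_pandas(). The API's queries project each row into a single column of pre-serialized JSON, which is why joining the first element of every row yields a ready-to-wrap JSON array body. A standalone sketch of that fold over assumed in-memory data:

    import pyarrow as pa

    # Assumed shape: one "Value" column of pre-serialized JSON per row,
    # mirroring what the API's SQL returns.
    batch = pa.table({"Value": ['{"TagName":"A","Value":1.0}',
                                '{"TagName":"B","Value":2.0}']})

    column_list = [column.to_pylist() for column in batch.columns]
    strings = ",".join(str(item[0]) for item in zip(*column_list))

    payload = {"data": strings, "count": batch.num_rows}
    # json_response() wraps this as "[" + payload["data"] + "]" without
    # ever building a DataFrame.
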
23 changes: 18 additions & 5 deletions src/sdk/python/rtdip_sdk/connectors/odbc/turbodbc_sql_connector.py
@@ -17,6 +17,7 @@
from ..._sdk_utils.compare_versions import _package_version_meets_minimum
from ..connection_interface import ConnectionInterface
from ..cursor_interface import CursorInterface
from ..models import ConnectionReturnType
import logging
import os

@@ -37,11 +38,18 @@ class TURBODBCSQLConnection(ConnectionInterface):
More fields such as driver can be configured upon extension.
"""

def __init__(self, server_hostname: str, http_path: str, access_token: str) -> None:
def __init__(
self,
server_hostname: str,
http_path: str,
access_token: str,
return_type=ConnectionReturnType.Pandas,
) -> None:
_package_version_meets_minimum("turbodbc", "4.0.0")
self.server_hostname = server_hostname
self.http_path = http_path
self.access_token = access_token
self.return_type = return_type
# call auth method
self.connection = self._connect()
self.open = True
@@ -97,7 +105,9 @@ def cursor(self) -> object:
try:
if self.open == False:
self.connection = self._connect()
return TURBODBCSQLCursor(self.connection.cursor())
return TURBODBCSQLCursor(
self.connection.cursor(), return_type=self.return_type
)
except Exception as e:
logging.exception("error with cursor object")
raise e
@@ -111,8 +121,9 @@ class TURBODBCSQLCursor(CursorInterface):
cursor: controls execution of commands on cluster or SQL Warehouse
"""

def __init__(self, cursor: object) -> None:
def __init__(self, cursor: object, return_type=ConnectionReturnType.Pandas) -> None:
self.cursor = cursor
self.return_type = return_type

def execute(self, query: str) -> None:
"""
@@ -136,8 +147,10 @@ def fetch_all(self) -> list:
"""
try:
result = self.cursor.fetchallarrow()
df = result.to_pandas()
return df
if self.return_type == ConnectionReturnType.Pyarrow:
return result
elif self.return_type == ConnectionReturnType.Pandas:
return result.to_pandas()
except Exception as e:
logging.exception("error while fetching the rows from the query")
raise e
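
The Turbodbc cursor materializes the whole result with fetchallarrow() and only branches between the Pyarrow and Pandas return types; the List and String paths implemented in the Databricks cursor above are not handled here, so those return types would fall through and return None. A usage sketch with placeholder credentials:

    from src.sdk.python.rtdip_sdk.connectors import (
        ConnectionReturnType,
        TURBODBCSQLConnection,
    )

    # Hypothetical workspace details; requires turbodbc >= 4.0.0.
    connection = TURBODBCSQLConnection(
        "adb-0000000000000000.0.azuredatabricks.net",
        "/sql/1.0/warehouses/0000000000000000",
        "dapi-example-token",
        return_type=ConnectionReturnType.Pyarrow,
    )
    cursor = connection.cursor()
    cursor.execute("SELECT * FROM sensors LIMIT 10")  # illustrative query
    arrow_table = cursor.fetch_all()  # pyarrow.Table; Pandas returns a DataFrame
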
@@ -65,7 +65,7 @@ def get_default_package(package_name):
"aws_boto3": PyPiLibrary(name="boto3", version="1.28.2"),
"hashicorp_vault": PyPiLibrary(name="hvac", version="1.1.0"),
"api_requests": PyPiLibrary(name="requests", version="2.30.0"),
"pyarrow": PyPiLibrary(name="pyarrow", version="12.0.0"),
"pyarrow": PyPiLibrary(name="pyarrow", version="14.0.2"),
"pandas": PyPiLibrary(name="pandas", version="2.0.1"),
}
return DEFAULT_PACKAGES[package_name]
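
A quick check of the bumped pin (PyPiLibrary exposes the version it was constructed with, as the test near the end of this diff confirms):

    # The default pyarrow dependency now resolves to the 14.0.2 pin.
    assert get_default_package("pyarrow").version == "14.0.2"
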
1 change: 1 addition & 0 deletions tests/api/v1/test_api_raw.py
@@ -17,6 +17,7 @@
import pandas as pd
import numpy as np
from datetime import datetime, timezone
from src.sdk.python.rtdip_sdk.authentication.azure import DefaultAuth
from tests.api.v1.api_test_objects import (
RAW_MOCKED_PARAMETER_DICT,
RAW_MOCKED_PARAMETER_ERROR_DICT,
36 changes: 18 additions & 18 deletions tests/conftest.py
@@ -99,9 +99,7 @@ def api_test_data():
}
mock_raw_data = test_raw_data.copy()
mock_raw_data["EventTime"] = mock_raw_data["EventTime"].strftime(datetime_format)
mock_raw_df = pd.DataFrame(
{"Value": [json.dumps(mock_raw_data, separators=(",", ":"))]}
)
mock_raw_df = {"data": json.dumps(mock_raw_data, separators=(",", ":")), "count": 1}
expected_raw = expected_result(test_raw_data)

# Mock Aggregated Data
@@ -112,9 +110,7 @@ def api_test_data():
}
mock_agg_data = test_agg_data.copy()
mock_agg_data["EventTime"] = mock_agg_data["EventTime"].strftime(datetime_format)
mock_agg_df = pd.DataFrame(
{"Value": [json.dumps(mock_agg_data, separators=(",", ":"))]}
)
mock_agg_df = {"data": json.dumps(mock_agg_data, separators=(",", ":")), "count": 1}
expected_agg = expected_result(test_agg_data)

# Summary Data
@@ -131,9 +127,10 @@

mock_plot_data = test_plot_data.copy()
mock_plot_data["EventTime"] = mock_plot_data["EventTime"].strftime(datetime_format)
mock_plot_df = pd.DataFrame(
{"Value": [json.dumps(mock_plot_data, separators=(",", ":"))]}
)
mock_plot_df = {
"data": json.dumps(mock_plot_data, separators=(",", ":")),
"count": 1,
}
expected_plot = expected_result(test_plot_data)

test_summary_data = {
@@ -147,9 +144,10 @@
"Var": 0.0,
}

mock_summary_df = pd.DataFrame(
{"Value": [json.dumps(test_summary_data, separators=(",", ":"))]}
)
mock_summary_df = {
"data": json.dumps(test_summary_data, separators=(",", ":")),
"count": 1,
}
expected_summary = expected_result(test_summary_data)

test_metadata = {
@@ -158,9 +156,10 @@
"Description": "Test Description",
}

mock_metadata_df = pd.DataFrame(
{"Value": [json.dumps(test_metadata, separators=(",", ":"))]}
)
mock_metadata_df = {
"data": json.dumps(test_metadata, separators=(",", ":")),
"count": 1,
}
expected_metadata = expected_result(test_metadata)

test_latest_data = {
@@ -181,9 +180,10 @@
mock_latest_data["GoodEventTime"] = mock_latest_data["GoodEventTime"].strftime(
datetime_format
)
mock_latest_df = pd.DataFrame(
{"Value": [json.dumps(mock_latest_data, separators=(",", ":"))]}
)
mock_latest_df = {
"data": json.dumps(mock_latest_data, separators=(",", ":")),
"count": 1,
}
expected_latest = expected_result(test_latest_data)

expected_sql = expected_result(test_raw_data, "100", "100")
@@ -45,7 +45,7 @@ def test_ssip_binary_file_to_pcdm_setup():
assert ssip_pi_binary_file_to_pcdm.libraries() == Libraries(
maven_libraries=[],
pypi_libraries=[
PyPiLibrary(name="pyarrow", version="12.0.0", repo=None),
PyPiLibrary(name="pyarrow", version="14.0.2", repo=None),
PyPiLibrary(name="pandas", version="2.0.1", repo=None),
],
pythonwheel_libraries=[],