Skip to content
This repository was archived by the owner on Mar 1, 2024. It is now read-only.

Commit

Permalink
python3.11 (#24)
Browse files Browse the repository at this point in the history
* support python 3.11 (#19)

* remove logs in httpx

* enable python 3.11 compatibility

* support WyvernAPI with aiohttp

* create the beta1 version for 0.0.8 (#21)

* use Enum to support both py<3.11 and py3.11 (#22)

* use Enum to support both py<3.11 and py3.11

* fix entity_type typing

* still use str for string enums for backward compatibility

* bump ci python version

* 0.0.8-beta2

* remove print in wyvern api
  • Loading branch information
wintonzheng authored Aug 21, 2023
1 parent c3c3e6a commit 0c4000e
Show file tree
Hide file tree
Showing 20 changed files with 723 additions and 129 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: pre-commit/action@v3.0.0
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
# reference the matrix python version here.
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"

# Cache the installation of Poetry itself, e.g. the next step. This prevents the workflow
# from installing Poetry every time, which can be slow. Note the use of the Poetry version
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ repos:
- tqdm
- types-tqdm
- nest-asyncio
- aiohttp
exclude: "^tests/"

# Check for spelling
Expand Down
2 changes: 1 addition & 1 deletion examples/real_time_features_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def compute_request_features(
return FeatureData(
identifier=Identifier(
identifier=request.request.request_id,
identifier_type=SimpleIdentifierType.REQUEST,
identifier_type=SimpleIdentifierType.REQUEST.value,
),
features={
"f_number_of_candidates": len(request.request.candidates),
Expand Down
594 changes: 592 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wyvern-ai"
version = "0.0.7"
version = "0.0.8-beta2"
description = ""
authors = ["Wyvern AI <[email protected]>"]
readme = "README.md"
Expand All @@ -9,7 +9,7 @@ packages = [
]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
python = ">=3.8,<3.12"
pydantic = "^1.10.4"
fastapi = "^0.95.2"
uvicorn = "^0.22.0"
Expand All @@ -19,7 +19,6 @@ pyhumps = "^3.8.0"
python-dotenv = "^1.0.0"
pandas = "1.5.3"
feast = {extras = ["redis", "snowflake"], version = "^0.31.1"}
httpx = "^0.24.1"
snowflake-connector-python = "3.0.3"
boto3 = "^1.26.146"
ddtrace = "^1.14.0"
Expand All @@ -30,6 +29,7 @@ tqdm = "^4.65.0"
nest-asyncio = "^1.5.7"
eppo-server-sdk = "^1.2.2"
scipy = "1.10.1"
aiohttp = {extras = ["speedups"], version = "^3.8.5"}


[tool.poetry.group.dev.dependencies]
Expand All @@ -46,6 +46,7 @@ types-boto3 = "^1.0.2"
pyinstrument = "^4.4.0"
pytest-dotenv = "^0.5.2"
ipykernel = "^6.25.0"
aioresponses = "^0.7.4"

[build-system]
requires = ["poetry-core"]
Expand Down
15 changes: 6 additions & 9 deletions tests/feature_store/test_real_time_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,19 @@ def mock_redis(mocker):
"""
Mocks the redis call. Each entry under `return_value` corresponds to a single entity fetch from Redis
"""
with mocker.patch(
mocker.patch(
"wyvern.redis.wyvern_redis.mget",
return_value=[None, None, None, None, None],
):
yield
)


@pytest.fixture
def mock_feature_store(mocker):
with mocker.patch.object(
mocker.patch.object(
feature_store_retrieval_component,
"fetch_features_from_feature_store",
return_value=FeatureMap(feature_map={}),
):
yield
)


@pytest.fixture
Expand Down Expand Up @@ -188,11 +186,10 @@ def mock_redis__2(mocker):
"""
Mocks the redis call. Each entry under `return_value` corresponds to a single entity fetch from Redis
"""
with mocker.patch(
mocker.patch(
"wyvern.redis.wyvern_redis.mget",
return_value=[None, None, None, None, None, None],
):
yield
)


@pytest.mark.asyncio
Expand Down
36 changes: 15 additions & 21 deletions tests/scenarios/test_product_ranking.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
import json
from functools import cached_property
from typing import Any, Dict, List, Optional, Set

import httpx
import pytest
import pytest_asyncio
from aioresponses import aioresponses
from fastapi.testclient import TestClient
from pydantic import BaseModel

Expand All @@ -19,7 +19,9 @@
ModelOutput,
)
from wyvern.components.pipeline_component import PipelineComponent
from wyvern.config import settings
from wyvern.core.compression import wyvern_encode
from wyvern.core.http import aiohttp_client
from wyvern.entities.candidate_entities import CandidateSetEntity
from wyvern.entities.feature_entities import FeatureData, FeatureMap
from wyvern.entities.identifier import Identifier
Expand Down Expand Up @@ -95,24 +97,16 @@
}


@pytest.fixture
def mock_httpx_post(mocker):
mocked_httpx_async_client = httpx.AsyncClient()
with mocker.patch.object(
mocked_httpx_async_client,
"post",
return_value=httpx.Response(
status_code=200,
content=json.dumps(ONLINE_FEATURE_RESPNOSE),
headers={},
json=ONLINE_FEATURE_RESPNOSE,
),
):
with mocker.patch(
"wyvern.components.features.feature_store.httpx_client",
return_value=mocked_httpx_async_client,
):
yield
@pytest_asyncio.fixture
async def mock_http_post(mocker):
    """
    Mocks the POST request to the online-features endpoint of the feature store.

    While the fixture is active, a POST to
    `settings.WYVERN_FEATURE_STORE_URL + settings.WYVERN_ONLINE_FEATURES_PATH`
    is answered by `aioresponses` with `ONLINE_FEATURE_RESPNOSE` as the payload.
    The shared `aiohttp_client` is started for the duration of the test and
    stopped again on teardown.
    """
    with aioresponses() as m:
        aiohttp_client.start()
        try:
            m.post(
                f"{settings.WYVERN_FEATURE_STORE_URL}{settings.WYVERN_ONLINE_FEATURES_PATH}",
                payload=ONLINE_FEATURE_RESPNOSE,
            )
            yield
        finally:
            # pytest skips the code after `yield` when the fixture fails before
            # yielding, so the stop has to live in `finally` to always pair with start().
            await aiohttp_client.stop()


@pytest.fixture
Expand Down Expand Up @@ -487,7 +481,7 @@ async def test_hydrate__duplicate_brand(mock_redis__duplicate_brand):


@pytest.mark.asyncio
async def test_end_to_end(mock_redis, mock_httpx_post, test_client):
async def test_end_to_end(mock_redis, mock_http_post, test_client):
response = test_client.post(
"/api/v1/product-search-ranking",
json={
Expand Down
53 changes: 37 additions & 16 deletions wyvern/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
from typing import Any, Dict, Hashable, List, Optional, Union
from functools import wraps
from typing import Any, Callable, Dict, Hashable, List, Optional, Union

import httpx
import aiohttp
import nest_asyncio
import pandas as pd
import requests
Expand All @@ -18,6 +21,22 @@
RETRY_PER_BATCH = 2


def ensure_async_client(func: Callable) -> Callable:
    """Decorate a synchronous WyvernAPI method so it runs with an open aiohttp session.

    Before the wrapped method runs, ``self.async_client`` is replaced with a
    fresh ``aiohttp.ClientSession`` (using ``HTTP_TIMEOUT`` as the total timeout)
    if the current one is already closed. Once the method has returned or
    raised, the session is closed if it is still open.
    """

    @wraps(func)
    def wrapper(self: WyvernAPI, *args, **kwargs):
        # An earlier decorated call closes the session on its way out, so rebuild it.
        if self.async_client.closed:
            session_timeout = aiohttp.ClientTimeout(total=HTTP_TIMEOUT)
            self.async_client = aiohttp.ClientSession(timeout=session_timeout)
        try:
            result = func(self, *args, **kwargs)
        finally:
            # close() is a coroutine; drive it to completion from this sync wrapper.
            if not self.async_client.closed:
                asyncio.run(self.async_client.close())
        return result

    return wrapper


class WyvernAPI:
def __init__(
self,
Expand All @@ -31,7 +50,9 @@ def __init__(
self.headers = {"x-api-key": api_key}
self.base_url = base_url or settings.WYVERN_BASE_URL
self.batch_size = batch_size
self.async_client = httpx.AsyncClient(timeout=HTTP_TIMEOUT)
self.async_client = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=HTTP_TIMEOUT),
)

def get_online_features(
self,
Expand All @@ -52,6 +73,7 @@ def get_online_features(
get_event_timestamps,
)

@ensure_async_client
def get_historical_features(
self,
features: List[str],
Expand Down Expand Up @@ -119,14 +141,13 @@ def get_historical_features(
desc="Fetching historical data",
unit="batch",
)
event_loop = _get_event_loop()
for i in range(num_gathers):
start_idx = i * BATCH_SIZE_PER_GATHER
end_idx = min((i + 1) * BATCH_SIZE_PER_GATHER, num_batches)
retry_count = 0
while retry_count < RETRY_PER_BATCH:
try:
gathered_responses = event_loop.run_until_complete(
gathered_responses = asyncio.run(
self.process_batches(data_batches[start_idx:end_idx]),
)
for response in gathered_responses:
Expand Down Expand Up @@ -192,17 +213,24 @@ async def _send_request_to_wyvern_api_async(
url = f"{self.base_url}{api_path}"
response = await self.async_client.post(url, headers=self.headers, json=data)

if response.status_code != 200:
self._handle_failed_request(response)
if response.status != 200:
await self._handle_failed_async_request(response)

return response.json()
return await response.json()

def _handle_failed_request(
self,
response: Union[httpx.Response, requests.Response],
response: requests.Response,
) -> None:
raise WyvernError(f"Request failed [{response.status_code}]: {response.text}")

async def _handle_failed_async_request(
    self,
    response: aiohttp.ClientResponse,
) -> None:
    """Raise a WyvernError for an aiohttp response that did not succeed.

    The response body is read as text and included in the error message
    together with the HTTP status.

    Raises:
        WyvernError: always.
    """
    body = await response.text()
    message = f"Request failed [{response.status}]: {body}"
    raise WyvernError(message)

def _convert_online_features_to_df(
self,
data,
Expand All @@ -227,10 +255,3 @@ def _convert_historical_features_to_df(
) -> pd.DataFrame:
df = pd.DataFrame(data["results"])
return df


def _get_event_loop():
try:
return asyncio.get_running_loop()
except RuntimeError:
return asyncio.new_event_loop()
13 changes: 7 additions & 6 deletions wyvern/components/features/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from wyvern.components.component import Component
from wyvern.config import settings
from wyvern.core.httpx import httpx_client
from wyvern.core.http import aiohttp_client
from wyvern.entities.feature_entities import (
FeatureData,
FeatureMap,
Expand Down Expand Up @@ -66,21 +66,22 @@ async def fetch_features_from_feature_store(
}
# TODO (suchintan) -- chunk + parallelize this
# TODO (Suchintan): This is currently busted in local development
response = await httpx_client().post(
response = await aiohttp_client().post(
f"{self.feature_store_host}{settings.WYVERN_ONLINE_FEATURES_PATH}",
headers=self.request_headers,
json=request_body,
)

if response.status_code != 200:
if response.status != 200:
resp_text = await response.text()
logger.error(
f"Error fetching features from feature store: [{response.status_code}] {response.json()}",
f"Error fetching features from feature store: [{response.status}] {resp_text}",
)
raise WyvernFeatureStoreError(error=response.json())
raise WyvernFeatureStoreError(error=resp_text)

# TODO (suchintan): More graceful response handling here

response_json = response.json()
response_json = await response.json()
feature_names = response_json["metadata"]["feature_names"]
feature_name_keys = [
feature_name.replace("__", ":", 1) for feature_name in feature_names
Expand Down
14 changes: 7 additions & 7 deletions wyvern/components/index/_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def execute(
bulk index entities with redis pipeline
"""

entity_internal_key = f"{input.entity_type}_id"
entity_internal_key = f"{input.entity_type.value}_id"
entity_key: str = input.entity_key or entity_internal_key

entities: List[Dict[str, Any]] = []
Expand Down Expand Up @@ -64,11 +64,11 @@ async def execute(
entity_ids = await wyvern_redis.bulk_index(
entities,
entity_key,
input.entity_type,
input.entity_type.value,
)

return IndexResponse(
entity_type=input.entity_type,
entity_type=input.entity_type.value,
entity_ids=entity_ids,
)

Expand All @@ -85,10 +85,10 @@ async def execute(
input: DeleteEntitiesRequest,
**kwargs,
) -> DeleteEntitiesResponse:
await WyvernIndex.bulk_delete(input.entity_type, input.entity_ids)
await WyvernIndex.bulk_delete(input.entity_type.value, input.entity_ids)
return DeleteEntitiesResponse(
entity_ids=input.entity_ids,
entity_type=input.entity_type,
entity_type=input.entity_type.value,
)


Expand All @@ -105,13 +105,13 @@ async def execute(
**kwargs,
) -> GetEntitiesResponse:
entities = await WyvernEntityIndex.bulk_get(
entity_type=input.entity_type,
entity_type=input.entity_type.value,
entity_ids=input.entity_ids,
)
if len(entities) != len(input.entity_ids):
raise WyvernError("Unexpected Error")
entity_map = {input.entity_ids[i]: entities[i] for i in range(len(entities))}
return GetEntitiesResponse(
entity_type=input.entity_type,
entity_type=input.entity_type.value,
entities=entity_map,
)
4 changes: 2 additions & 2 deletions wyvern/components/models/modelbit_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ModelComponent,
)
from wyvern.config import settings
from wyvern.core.httpx import httpx_client
from wyvern.core.http import aiohttp_client
from wyvern.entities.identifier import Identifier
from wyvern.entities.identifier_entities import WyvernEntity
from wyvern.entities.request import BaseWyvernRequest
Expand Down Expand Up @@ -109,7 +109,7 @@ async def inference(self, input: MODEL_INPUT, **kwargs) -> MODEL_OUTPUT:

# split requests into smaller batches and parallelize them
futures = [
httpx_client().post(
aiohttp_client().post(
self._modelbit_url,
headers=self.headers,
json={"data": all_requests[i : i + settings.MODELBIT_BATCH_SIZE]},
Expand Down
Loading

0 comments on commit 0c4000e

Please sign in to comment.