Skip to content

Commit

Permalink
feat(table-wrapper): add default value for table models
Browse files Browse the repository at this point in the history
  • Loading branch information
Yax94 authored and matbleu committed Feb 6, 2023
1 parent fba6434 commit fc31acf
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Pythie serving

## 3.2.0
## Change
* Add a default value from metadata for table models

## 3.1.0
## Change
* Add GRPC server timeout (in seconds) in var env `GRPC_SERVER_TIMEOUT`. Default to None.
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.1.0
3.2.0
40 changes: 29 additions & 11 deletions src/pythie_serving/table_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv
import json
from typing import Any

import numpy as np
from numpy.typing import NDArray
Expand All @@ -19,6 +20,18 @@
class TablePredictionServiceServicer(AbstractPythieServingPredictionServiceServicer):
model_file_extension = ".csv"

def _parse_extra_specs(self, metadata: dict[str, Any], table_type_mapping: dict[str, Any]) -> dict[str, Any] | None:
extra_specs = None
if "default_value" in metadata:
default_value = metadata["default_value"]
if type(default_value) != table_type_mapping[metadata["target_name"]]:
raise PythieServingException(
f"Can't assign default_value {default_value} because it's type {type(default_value)} is "
f"different from target type {table_type_mapping[metadata['target_name']]}"
)
extra_specs = {"default_value": default_value}
return extra_specs

def _create_model_specs(self, model_config: ModelConfig) -> ModelSpecs:

with open(self._get_metadata_path(model_config)) as f:
Expand All @@ -44,13 +57,15 @@ def _create_model_specs(self, model_config: ModelConfig) -> ModelSpecs:
value = table_type_mapping[metadata["target_name"]](row[metadata["target_name"]])
table[key] = value

return {
"model": table,
"feature_names": metadata["feature_names"],
"nb_features": len(metadata["feature_names"]),
"samples_dtype": object,
"extra_specs": None,
}
extra_specs = self._parse_extra_specs(metadata, table_type_mapping)

return ModelSpecs(
model=table,
feature_names=metadata["feature_names"],
nb_features=len(metadata["feature_names"]),
samples_dtype=object,
extra_specs=extra_specs,
)

def _predict(self, model_specs: ModelSpecs, samples: NDArray) -> NDArray:

Expand All @@ -59,10 +74,13 @@ def _predict(self, model_specs: ModelSpecs, samples: NDArray) -> NDArray:
try:
pred = model_specs["model"][tuple(feature_value for feature_value in sample)]
except KeyError:
raise PythieServingException(
f"No prediction found in table for given features: " f"{model_specs['feature_names']} = {sample}."
)
if model_specs["extra_specs"] is not None and "default_value" in model_specs["extra_specs"]:
pred = model_specs["extra_specs"]["default_value"]
else:
raise PythieServingException(
f"No prediction found in table for given features: "
f"{model_specs['feature_names']} = {sample}."
)

output[idx] = pred

return output

0 comments on commit fc31acf

Please sign in to comment.