Skip to content

Commit

Permalink
support metadata on context upload (#14)
Browse files Browse the repository at this point in the history
* support metadata on context upload. need testing

* cleaner

* tested e2e
  • Loading branch information
Ben-Epstein authored Sep 28, 2023
1 parent b460eb8 commit e372d10
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 41 deletions.
4 changes: 2 additions & 2 deletions arcee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from arcee import config
from arcee.api import get_dalm, get_dalm_status, train_dalm, upload_doc, upload_docs
from arcee.dalm import DALM
from arcee.dalm import DALM, DALMFilter

if not config.ARCEE_API_KEY:
# We check this because it's impossible the user imported arcee, _then_ set the env, then imported again
Expand All @@ -13,4 +13,4 @@
config.ARCEE_API_KEY = input("ARCEE_API_KEY not found in environment. Please input api key: ")
os.environ["ARCEE_API_KEY"] = config.ARCEE_API_KEY

__all__ = ["upload_docs", "upload_doc", "train_dalm", "get_dalm", "DALM", "get_dalm_status"]
__all__ = ["upload_docs", "upload_doc", "train_dalm", "get_dalm", "DALM", "get_dalm_status", "DALMFilter"]
14 changes: 11 additions & 3 deletions arcee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
from arcee.schemas.routes import Route


def upload_doc(context: str, doc_name: str, doc_text: str) -> dict[str, str]:
def upload_doc(context: str, doc_name: str, doc_text: str, **kwargs: dict[str, int | float | str]) -> dict[str, str]:
"""
Upload a document to a context
Args:
context (str): The name of the context to upload to
doc_name (str): The name of the document
doc_text (str): The text of the document
kwargs: Any other key:value pairs to be included as extra metadata along with your doc
"""
doc = {"name": doc_name, "document": doc_text}
doc = {"name": doc_name, "document": doc_text, "meta": kwargs}
data = {"context_name": context, "documents": [doc]}
return make_request("post", Route.contexts, data)

Expand All @@ -27,13 +28,20 @@ def upload_docs(context: str, docs: list[dict[str, str]]) -> dict[str, str]:
Args:
context (str): The name of the context to upload to
docs (list): A list of dictionaries with keys "doc_name" and "doc_text"
Any other keys in the `docs` will be assumed as metadata, and will be uploaded as such. This metadata can
be filtered on during retrieval and generation.
"""
doc_list = []
for doc in docs:
if "doc_name" not in doc.keys() or "doc_text" not in doc.keys():
raise Exception("Each document must have a doc_name and doc_text key")

doc_list.append({"name": doc["doc_name"], "document": doc["doc_text"]})
new_doc: dict[str, str | dict] = {"name": doc.pop("doc_name"), "document": doc.pop("doc_text")}
# Any other keys are metadata
if doc:
new_doc["meta"] = doc
doc_list.append(new_doc)

data = {"context_name": context, "documents": doc_list}
return make_request("post", Route.contexts, data)
Expand Down
3 changes: 3 additions & 0 deletions arcee/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ def context(
"""Upload document(s) to context. If a directory is provided, all valid files in the directory will be uploaded.
At least one of file or directory must be provided.
If you are using CSV or jsonl file(s), every key/column in your dataset that isn't that of `doc_name` and `doc_text`
will be uploaded as extra metadata fields with your doc. These can be used for filtering on generation and retrieval
Args:
name (str): Name of the context
file (Path): Path to the file.
Expand Down
2 changes: 1 addition & 1 deletion arcee/cli_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _get_docs(cls, file: Path, doc_name: str, doc_text: str) -> list[Doc]:
f"{doc_text} not found in data column/key. Rename column/key or use "
f"--doc-text in comment to specify your own"
)
return [Doc(doc_name=row[doc_name], doc_text=row[doc_text]) for _, row in df.iterrows()]
return [Doc(doc_name=row.pop(doc_name), doc_text=row.pop(doc_text), meta=dict(row)) for _, row in df.iterrows()]

@classmethod
def _handle_upload(
Expand Down
6 changes: 0 additions & 6 deletions arcee/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,4 @@ def get_conditional_configuration_variable(key: str, default: str) -> str:
ARCEE_API_URL = get_conditional_configuration_variable("ARCEE_API_URL", "https://api.arcee.ai")
ARCEE_APP_URL = get_conditional_configuration_variable("ARCEE_APP_URL", "https://app.arcee.ai")
ARCEE_API_KEY = get_conditional_configuration_variable("ARCEE_API_KEY", "")
ARCEE_RETRIEVAL_URL = get_conditional_configuration_variable(
"ARCEE_QUERY_URL", "https://3fjzbjz9ne.execute-api.us-east-2.amazonaws.com/prod/retrieve"
)
ARCEE_GENERATION_URL = get_conditional_configuration_variable(
"ARCEE_GENERATION_URL", "https://3fjzbjz9ne.execute-api.us-east-2.amazonaws.com/prod/generate"
)
ARCEE_API_VERSION = get_conditional_configuration_variable("ARCEE_API_VERSION", "v2")
96 changes: 69 additions & 27 deletions arcee/dalm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Literal
from enum import Enum
from typing import Any, Literal, Optional

import requests
from pydantic import BaseModel, model_validator

from arcee import config
from arcee.api_handler import make_request, retry_call
from arcee.api_handler import make_request
from arcee.schemas.routes import Route


Expand All @@ -12,6 +12,39 @@ def check_model_status(name: str) -> dict[str, str]:
return make_request("get", route)


class FilterType(str, Enum):
    """Search modes usable as the `filter_type` of a `DALMFilter`.

    The `str` mixin makes every member compare equal to its plain string value.
    """

    fuzzy_search = "fuzzy_search"
    strict_search = "strict_search"


class DALMFilter(BaseModel):
    """A single filter applied to a dalm retrieve/generation query

    Arguments:
        field_name: The field to filter on. Use 'document' or 'name' to filter on your document's raw text or
            title. Any other field is presumed to be a metadata field you included when uploading your
            context data
        filter_type: Currently 'fuzzy_search' and 'strict_search' are supported. More to come soon!
            'fuzzy_search' performs a fuzzy search on the provided field. The exact string doesn't need to
            exist in the document for this to find a match. Very useful for scanning a document for some
            keyword terms
            'strict_search' requires the exact string to appear in the provided field. This is NOT an exact
            equality filter, i.e. a document with content "the happy dog crossed the street" will match a
            strict_search of "dog" but won't match "the dog". Python equivalent of
            `return search_string in full_string`
        value: The actual value to search for in the context data/metadata
    """

    field_name: str
    filter_type: FilterType
    value: str
    # Derived from `field_name` by `set_meta` once validation has run.
    # NOTE(review): underscore-prefixed, so presumably a pydantic private attribute that
    # `model_dump()` leaves out of the payload - confirm this is intended.
    _is_metadata: bool = False

    @model_validator(mode="after")
    def set_meta(self) -> "DALMFilter":
        """document and name are reserved arcee keys. Anything else is metadata"""
        targets_reserved_key = self.field_name in ("document", "name")
        self._is_metadata = not targets_reserved_key
        return self


class DALM:
def __init__(self, name: str) -> None:
self.name = name
Expand All @@ -24,26 +57,35 @@ def __init__(self, name: str) -> None:
if self.status != "training_complete":
raise Exception("DALM model is not ready. Please wait for training to complete.")

# if ever separate retriever services from arcee
# self.retriever_url = retriever_api_response["retriever_url"]
self.generate_url = config.ARCEE_GENERATION_URL
self.retriever_url = config.ARCEE_RETRIEVAL_URL

@retry_call(wait_sec=0.5)
def invoke(self, invocation_type: Literal["retrieve", "generate"], query: str, size: int) -> dict[str, Any]:
url = self.retriever_url if invocation_type == "retrieve" else self.generate_url
payload = {"model_id": self.model_id, "query": query, "size": size}
headers = {"Authorization": f"Bearer {config.ARCEE_API_KEY}"}

response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
raise Exception(f"Failed to {invocation_type}. Response: {response.text}")
return response.json()

def retrieve(self, query: str, size: int = 3) -> dict:
"""Retrieve {size} contexts with your retriever for the given query"""
return self.invoke("retrieve", query, size)

def generate(self, query: str, size: int = 3) -> dict:
"""Generate a response using {size} contexts with your generator for the given query"""
return self.invoke("generate", query, size)
def invoke(
    self, invocation_type: Literal["retrieve", "generate"], query: str, size: int, filters: list[dict]
) -> dict[str, Any]:
    """POST a retrieve or generate request for this model and return the response

    Arguments:
        invocation_type: "retrieve" selects `Route.retrieve`; any other value selects `Route.generate`
        query: The query text, sent under "query"
        size: Sent under "size"
        filters: Filter dicts, sent unchanged under "filters"
    """
    if invocation_type == "retrieve":
        route = Route.retrieve
    else:
        route = Route.generate

    # The model id is sent under both "model_id" and "id".
    payload = {
        "model_id": self.model_id,
        "query": query,
        "size": size,
        "filters": filters,
        "id": self.model_id,
    }
    return make_request("post", route, body=payload)

def retrieve(self, query: str, size: int = 3, filters: Optional[list[DALMFilter]] = None) -> dict:
    """Retrieve {size} contexts with your retriever for the given query

    Arguments:
        query: The question to submit to the model
        size: The max number of context results to retrieve (can be less if filters are provided)
        filters: Optional filters to include with the query. This will restrict which context data the model
            can retrieve from the context dataset
    """
    ret_filters = []
    # Each entry is validated as a DALMFilter and then dumped to a plain dict for the payload.
    for candidate in filters or []:
        validated = DALMFilter.model_validate(candidate)
        ret_filters.append(validated.model_dump())
    return self.invoke("retrieve", query, size, ret_filters)

def generate(self, query: str, size: int = 3, filters: Optional[list[DALMFilter]] = None) -> dict:
    """Generate a response using {size} contexts with your generator for the given query

    Arguments:
        query: The question to submit to the model
        size: The max number of context results to retrieve (can be less if filters are provided)
        filters: Optional filters to include with the query. This will restrict which context data the model
            can retrieve from the context dataset
    """
    requested = filters if filters else []
    # Validate every filter as a DALMFilter, then serialise it to a plain dict for the payload.
    gen_filters = [DALMFilter.model_validate(item).model_dump() for item in requested]
    return self.invoke("generate", query, size, gen_filters)
5 changes: 3 additions & 2 deletions arcee/schemas/doc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import asdict, dataclass
from dataclasses import dataclass


@dataclass
class Doc:
doc_name: str
doc_text: str
meta: dict | None = None

def dict(self) -> dict[str, str]:
return asdict(self)
return {"doc_name": self.doc_name, "doc_text": self.doc_text, **(self.meta or {})}
2 changes: 2 additions & 0 deletions arcee/schemas/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ class Route(str, Enum):
train_retriever = "retrievers/train"
train_retriever_status = "retrievers/status/{id_or_name}"
identity = "whoami"
retrieve = "models/retrieve"
generate = "models/generate"
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"requests",
"typer",
"rich",
"pydantic"
]

[project.scripts]
Expand Down Expand Up @@ -88,6 +89,7 @@ ban-relative-imports = "all"

[tool.mypy]
disallow_untyped_defs = true
plugins = "pydantic.mypy"

[[tool.mypy.overrides]]
module = "retrying.*"
Expand Down

0 comments on commit e372d10

Please sign in to comment.