Merge remote-tracking branch 'origin/main' into index-deletion
bmschmidt committed Apr 19, 2024
2 parents 2d73ba5 + 7d59e7a commit ac7b9af
Showing 10 changed files with 399 additions and 152 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
@@ -42,15 +42,15 @@ commands:
jobs:
build_test:
machine:
-image: ubuntu-2004:202010-01 # recommended linux image
+image: default
resource_class: large
steps:
- prep
- test

build_test_deploy:
machine:
-image: ubuntu-2004:202010-01 # recommended linux image
+image: default
resource_class: large
steps:
- prep
@@ -130,4 +130,4 @@ workflows:
# filters:
# branches:
# only:
-# - master
\ No newline at end of file
+# - master
76 changes: 38 additions & 38 deletions examples/sagemaker/run-nomic-embed-text.ipynb
@@ -16,16 +16,19 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install nomic"
"!pip install nomic\n",
"!pip install numpy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 12,
"id": "cc3b38b6-34ef-48bd-923d-938b88471873",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from nomic.aws.sagemaker import embed_texts"
]
},
@@ -40,12 +43,12 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "cf4e8007-3fa2-473b-8748-7aa4de26cb2f",
"metadata": {},
"outputs": [],
"source": [
"endpoint_name = 'triton-nomic-embed-text-v1-5-test-2024-03-22-17-44-47'\n",
"endpoint_name = 'nomic-embed-endpoint'\n",
"region_name = 'us-east-2'"
]
},
@@ -61,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "9bd45a53-de0a-4a63-af40-731c218b4ea4",
"metadata": {},
"outputs": [],
@@ -78,65 +81,62 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "a668e03f-9a50-4b3d-9a02-a059136af6b9",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.93it/s]\n"
]
}
],
"source": [
"embeddings = embed_texts(texts, endpoint_name, region_name=region_name, batch_size=32)"
"response = embed_texts(texts, endpoint_name, region_name=region_name, batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "3b704bd6-5951-40c3-ba26-dfdedb4a7ff1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6, 768)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"embeddings.shape"
"embeddings = response[\"embeddings\"]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 13,
"id": "d0a81001-a832-4b27-b9ec-65ec68b133e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.04666 , 0.02484 , -0.1688 , ..., -0.04614 , -0.01462 ,\n",
" -0.01997 ],\n",
" [ 0.04868 , 0.0272 , -0.1686 , ..., -0.04138 , -0.03778 ,\n",
" -0.02217 ],\n",
" [ 0.05188 , 0.01855 , -0.1583 , ..., -0.0631 , -0.005123,\n",
" -0.01807 ],\n",
" [-0.01142 , 0.02853 , -0.1456 , ..., -0.03638 , -0.0711 ,\n",
" 0.005814],\n",
" [-0.01142 , 0.02853 , -0.1456 , ..., -0.03638 , -0.0711 ,\n",
" 0.005814],\n",
" [-0.01142 , 0.02853 , -0.1456 , ..., -0.03638 , -0.0711 ,\n",
" 0.005814]], dtype=float16)"
"array([[ 0.03738403, 0.00876617, -0.1116333 , ..., -0.04412842,\n",
" -0.04345703, -0.0524292 ],\n",
" [ 0.03637695, 0.01615906, -0.12445068, ..., -0.04266357,\n",
" -0.06054688, -0.05432129],\n",
" [ 0.05923462, 0.02310181, -0.1315918 , ..., -0.05889893,\n",
" -0.03872681, -0.04345703],\n",
" [-0.01360321, 0.04324341, -0.16638184, ..., -0.05523682,\n",
" -0.07879639, -0.00566101],\n",
" [-0.01360321, 0.04324341, -0.16638184, ..., -0.05523682,\n",
" -0.07879639, -0.00566101],\n",
" [-0.01360321, 0.04324341, -0.16638184, ..., -0.05523682,\n",
" -0.07879639, -0.00566101]])"
]
},
"execution_count": 6,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings"
"np.array(embeddings)"
]
},
{
@@ -164,7 +164,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.9.16"
}
},
"nbformat": 4,
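The net effect of the notebook changes: `embed_texts` no longer returns an array directly, so the cells now unpack a response dict and rebuild the matrix with numpy. A condensed sketch of the updated call pattern (the endpoint name and region are placeholders, not a live deployment):

import numpy as np

from nomic.aws.sagemaker import embed_texts

texts = ["hello world"]  # any list of strings
response = embed_texts(texts, "nomic-embed-endpoint", region_name="us-east-2", batch_size=32)

# Embeddings now arrive as a plain 2D Python list inside a dict.
embeddings = np.array(response["embeddings"])  # shape: (len(texts), 768)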
8 changes: 4 additions & 4 deletions makefile
@@ -5,13 +5,13 @@ PYTHON:=python3

all: venv
source env/bin/activate; python -m pip install --upgrade pip
-source env/bin/activate; pip install --use-deprecated=legacy-resolver -e .
+source env/bin/activate; pip install -e .

venv:
if [ ! -d $(ROOT_DIR)/env ]; then $(PYTHON) -m venv $(ROOT_DIR)/env; fi

dev: all
-source env/bin/activate; pip install --use-deprecated=legacy-resolver -e ".[dev, aws]"
+source env/bin/activate; pip install -e ".[dev]"

black:
source env/bin/activate; black -l 120 -S --target-version py36 nomic
@@ -36,7 +36,7 @@ lint:
pretty: isort black

test:
-source env/bin/activate; pytest -s nomic/tests
+source env/bin/activate; pytest -s tests
clean:
rm -rf {.pytest_cache,env,nomic.egg-info}
-find . | grep -E "(__pycache__|\.pyc|\.pyo$\)" | xargs rm -rf
\ No newline at end of file
+find . | grep -E "(__pycache__|\.pyc|\.pyo$\)" | xargs rm -rf
70 changes: 39 additions & 31 deletions nomic/aws/sagemaker.py
@@ -10,22 +10,6 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

-# NOTE: Currently Sagemaker only supports nomic-embed-text-v1.5 model.
-
-SAGEMAKER_MODELS = {"nomic-embed-text-v1.5": {"us-east-2": "TODO: ARN"}}
-
-
-def _get_supported_regions(model: str):
-    return SAGEMAKER_MODELS[model].keys()
-
-
-def _get_model_and_region_for_arn(arn: str):
-    for model in SAGEMAKER_MODELS:
-        for region, arn in SAGEMAKER_MODELS[model].items():
-            if arn == arn:
-                return model, region
-    raise ValueError(f"Model package arn {arn} not supported.")


def _get_sagemaker_role():
try:
@@ -48,7 +32,27 @@ def parse_sagemaker_response(response):
"""
# Parse json header size length from the response
resp = json.loads(response["Body"].read().decode())
-return np.array(resp["embeddings"], dtype=np.float16)
+return resp["embeddings"]


+def preprocess_texts(texts: List[str], task_type: str = "search_document"):
+    """
+    Preprocess a list of texts for embedding using a sagemaker model.
+    Args:
+        texts: List of texts to be embedded.
+        task_type: The task type to use when embedding. One of `search_query`, `search_document`, `classification`, `clustering`
+    Returns:
+        List of texts formatted for sagemaker embedding.
+    """
+    assert task_type in [
+        "search_query",
+        "search_document",
+        "classification",
+        "clustering",
+    ], f"Invalid task type: {task_type}"
+    return [f"{task_type}: {text}" for text in texts]


def batch_transform(
@@ -83,15 +87,6 @@ def batch_transform(
"""
if arn is None:
raise ValueError("model package arn is currently required.")
-if region_name is None or model_name is None:
-    raise ValueError(
-        "region_name and model_name is required if arn is not provided."
-    )
-if region_name not in _get_supported_regions(model_name):
-    raise ValueError(
-        f"Model {model_name} not supported in region {region_name}."
-    )
-arn = SAGEMAKER_MODELS[model_name][region_name]

if role is None:
logger.info("No role provided. Using default sagemaker role.")
@@ -131,7 +126,11 @@


def embed_texts(
-texts: List[str], sagemaker_endpoint: str, region_name: str, batch_size=32
+texts: List[str],
+sagemaker_endpoint: str,
+region_name: str,
+task_type: str = "search_document",
+batch_size: int = 32,
):
"""
Embed a list of texts using a sagemaker model endpoint.
@@ -140,15 +139,19 @@
texts: List of texts to be embedded.
sagemaker_endpoint: The sagemaker endpoint to use.
region_name: AWS region sagemaker endpoint is in.
-batch_size: Size of each batch.
+task_type: The task type to use when embedding.
+batch_size: Size of each batch. Default is 32.
Returns:
-np.float16 array of embeddings.
+Dictionary with "embeddings" (python 2d list of floats), "model" (sagemaker endpoint used to generate embeddings).
"""

if len(texts) == 0:
logger.warning("No texts to embed.")
return None

+texts = preprocess_texts(texts, task_type)

client = boto3.client("sagemaker-runtime", region_name=region_name)
embeddings = []

@@ -157,5 +160,10 @@
response = client.invoke_endpoint(
EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json"
)
-embeddings.append(parse_sagemaker_response(response))
-return np.vstack(embeddings)
+embeddings.extend(parse_sagemaker_response(response))

+return {
+    "embeddings": embeddings,
+    "model": "nomic-embed-text-v1.5",
+    "usage": {},
+}
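Two behavioral changes fall out of this file's diff: `embed_texts` now returns a dict ("embeddings", "model", "usage") instead of an np.float16 array, and inputs are routed through the new `preprocess_texts` helper, which prepends the task-type prefix that nomic-embed-text models expect. A small illustration of the helper's contract, inferred directly from its return expression and assertion:

from nomic.aws.sagemaker import preprocess_texts

print(preprocess_texts(["hello world"], task_type="search_query"))
# -> ['search_query: hello world']

# Anything outside search_query / search_document / classification / clustering
# trips the assertion:
# preprocess_texts(["hi"], task_type="summarize")  # AssertionError: Invalid task type: summarize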
16 changes: 12 additions & 4 deletions nomic/data_operations.py
@@ -817,7 +817,11 @@ def _read_prefetched_tiles_with_sidecars(self, additional_sidecars=None):
for col in carfile.column_names:
tb = tb.append_column(col, carfile[col])
for big_sidecar in additional_sidecars:
-fname = base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
+fname = (
+    base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8")
+    if big_sidecar != 'datum_id'
+    else big_sidecar
+)
carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True)
for col in carfile.column_names:
tb = tb.append_column(col, carfile[col])
@@ -835,6 +839,7 @@ def _download_data(self, fields=None):

all_quads = list(self.projection._tiles_in_order(coords_only=True))
sidecars = fields
+registered_sidecars = self.projection._registered_sidecars()
if sidecars is None:
sidecars = [
field
@@ -844,18 +849,21 @@
else:
for field in sidecars:
assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields."
+encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars]
+if any(sidecar == 'datum_id' for (field, sidecar) in registered_sidecars):
+    sidecars.append('datum_id')
+    encoded_sidecars.append('datum_id')

for quad in tqdm(all_quads):
-for sidecar in sidecars:
+for encoded_colname in encoded_sidecars:
quad_str = os.path.join(*[str(q) for q in quad])
-encoded_colname = base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8")
filename = quad_str + "." + encoded_colname + ".feather"
path = self.projection.tile_destination / Path(filename)

if not os.path.exists(path):
# WARNING: Potentially large data request here
download_feather(root + filename, path, headers=self.dataset.header)

return sidecars

@property
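Both hunks in this file apply the same naming rule: a sidecar column is stored in a feather file named after its urlsafe-base64-encoded column name, with `datum_id` newly exempted and kept literal. A hedged sketch of that convention (the helper name and tile path are illustrative, not part of the library):

import base64

def sidecar_filename(quad_str: str, sidecar: str) -> str:
    # 'datum_id' keeps its literal name; every other column name is
    # urlsafe-base64 encoded before being embedded in the filename.
    encoded = (
        sidecar
        if sidecar == "datum_id"
        else base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8")
    )
    return f"{quad_str}.{encoded}.feather"

print(sidecar_filename("0/0/0", "text"))      # 0/0/0.dGV4dA==.feather
print(sidecar_filename("0/0/0", "datum_id"))  # 0/0/0.datum_id.feather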
[Diffs for the remaining 5 changed files are not shown.]
