add index deletion, linting catastrophe.
bmschmidt committed Apr 19, 2024
1 parent ac7b9af commit a176554
Showing 8 changed files with 136 additions and 254 deletions.
nomic/atlas.py: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 from tqdm import tqdm
 
 from .data_inference import NomicDuplicatesOptions, NomicEmbedOptions, NomicProjectOptions, NomicTopicOptions
-from .dataset import AtlasDataStream, AtlasDataset
+from .dataset import AtlasDataset, AtlasDataStream
 from .settings import *
 from .utils import arrow_iterator, b64int, get_random_name
 
@@ -61,7 +61,7 @@ def map_data(
         project_name = get_random_name()
 
     dataset_name = project_name
-    index_name=dataset_name
+    index_name = dataset_name
 
     if identifier:
         dataset_name = identifier
nomic/aws/sagemaker.py: 2 additions & 6 deletions
@@ -15,9 +15,7 @@ def _get_sagemaker_role():
     try:
         return sagemaker.get_execution_role()
     except ValueError:
-        raise ValueError(
-            "Unable to fetch sagemaker execution role. Please provide a role."
-        )
+        raise ValueError("Unable to fetch sagemaker execution role. Please provide a role.")
 
 
 def parse_sagemaker_response(response):
@@ -157,9 +155,7 @@ def embed_texts(
 
     for i in tqdm(range(0, len(texts), batch_size)):
         batch = json.dumps({"texts": texts[i : i + batch_size]})
-        response = client.invoke_endpoint(
-            EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json"
-        )
+        response = client.invoke_endpoint(EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json")
         embeddings.extend(parse_sagemaker_response(response))
 
     return {
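For context, the batching pattern that the embed_texts hunk collapses onto one line looks roughly like the sketch below. This is a minimal sketch, not the library's code: it assumes configured AWS credentials, a hypothetical endpoint name, and an endpoint that answers with a JSON body of the form {"embeddings": [...]}; it parses the JSON inline rather than through the file's parse_sagemaker_response helper, whose exact response shape is not shown in this diff.

import json

import boto3

client = boto3.client("sagemaker-runtime")  # assumes AWS credentials are configured

texts = ["first document", "second document", "third document"]
batch_size = 2  # illustrative; embed_texts takes this as a parameter
embeddings = []

for i in range(0, len(texts), batch_size):
    # Serialize one batch of texts as the JSON payload the endpoint expects.
    batch = json.dumps({"texts": texts[i : i + batch_size]})
    response = client.invoke_endpoint(
        EndpointName="my-embedding-endpoint",  # hypothetical endpoint name
        Body=batch,
        ContentType="application/json",
    )
    # response["Body"] is a StreamingBody; assume {"embeddings": [...]} comes back.
    embeddings.extend(json.loads(response["Body"].read())["embeddings"])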
nomic/cli.py: 2 additions & 3 deletions
@@ -1,10 +1,10 @@
 import json
-import jwt
 import os
 import time
 from pathlib import Path
 
 import click
+import jwt
 import requests
 from rich.console import Console
 
@@ -63,7 +63,6 @@ def login(token, tenant='production', domain=None):
     if not nomic_base_path.exists():
         nomic_base_path.mkdir()
 
-
     expires = None
     refresh_token = None
 
@@ -85,7 +84,7 @@
         'refresh_token': refresh_token,
         'token': bearer_token,
         'tenant': tenant,
-        'expires': expires
+        'expires': expires,
     }
 
     if tenant == 'enterprise':
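login stores a bearer token next to an expires value, which is presumably why this file imports jwt (the import is only reordered here). For reference, reading a JWT's expiry with the PyJWT package looks roughly like this sketch; the "exp" claim is standard JWT, not specific to this codebase:

import jwt

def token_expiry(bearer_token: str):
    # Decode without verifying the signature; we only want the standard "exp" claim.
    claims = jwt.decode(bearer_token, options={"verify_signature": False})
    return claims.get("exp")  # Unix timestamp, or None if the token has no expiry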
nomic/data_inference.py: 1 addition & 3 deletions
@@ -3,9 +3,7 @@
 import pyarrow as pa
 from pydantic import BaseModel, Field
 
-from .settings import (
-    DEFAULT_DUPLICATE_THRESHOLD,
-)
+from .settings import DEFAULT_DUPLICATE_THRESHOLD
 
 
 def from_list(values: Dict[str, Any], schema=None) -> pa.Table:
nomic/data_operations.py: 30 additions & 23 deletions
@@ -9,7 +9,7 @@
 from datetime import datetime
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Iterable, Optional, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import pandas
@@ -36,12 +36,16 @@ def __init__(self, projection: "AtlasProjection"):
         self.projection = projection
         self.id_field = self.projection.dataset.id_field
         try:
-            duplicate_fields = [field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field]
+            duplicate_fields = [
+                field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field
+            ]
             cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field]
             assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map."
             self.duplicate_field = duplicate_fields[0]
             self.cluster_field = cluster_fields[0]
-            self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, self.duplicate_field, self.cluster_field])
+            self._tb: pa.Table = projection._fetch_tiles().select(
+                [self.id_field, self.duplicate_field, self.cluster_field]
+            )
         except pa.lib.ArrowInvalid as e:
             raise ValueError("Duplicate detection has not yet been run on this map.")
         self.duplicate_field = self.duplicate_field.lstrip("_")
@@ -75,7 +79,9 @@ def deletion_candidates(self) -> List[str]:
 
     def __repr__(self) -> str:
         repr = f"===Atlas Duplicates for ({self.projection})\n"
-        duplicate_count = len(self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')))
+        duplicate_count = len(
+            self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate'))
+        )
         cluster_count = len(self.tb[self.cluster_field].value_counts())
         repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n"
        return repr + self.df.__repr__()
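The duplicate_count expression rewrapped above filters a pyarrow column by duplicate class before counting. A self-contained sketch of that pattern with made-up data ('deletion candidate' appears in the diff; the other class labels and the table itself are illustrative):

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical stand-in for the fetched tiles: an id column plus a duplicate-class column.
tb = pa.table(
    {
        "id": ["a", "b", "c", "d"],
        "duplicate_class": ["singleton", "deletion candidate", "singleton", "deletion candidate"],
    }
)

# Boolean mask of rows whose duplicate class marks them as deletion candidates.
mask = pc.equal(tb["duplicate_class"], "deletion candidate")

# Filter the id column with the mask and count the surviving rows.
duplicate_count = len(tb["id"].filter(mask))
print(duplicate_count)  # 2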
@@ -453,7 +459,7 @@ def _download_latent(self):
         route = self.projection.dataset.atlas_api_path + '/v1/project/data/get/embedding/paged'
         last = None
 
-        with tqdm(total=self.dataset.total_datums//limit) as pbar:
+        with tqdm(total=self.dataset.total_datums // limit) as pbar:
             while True:
                 params = {'projection_id': self.projection.id, "last_file": last, "page_size": limit}
                 r = requests.post(route, headers=self.projection.dataset.header, json=params)
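This hunk touches a cursor-paginated download loop: each request passes the last_file cursor from the previous page. A minimal sketch of that pagination shape, with a hypothetical host, ids, and response layout (auth headers omitted; the real response fields are not shown in this diff):

import requests

route = "https://example.com/v1/project/data/get/embedding/paged"  # hypothetical host; path from the diff
limit = 1000
last = None
pages = []

while True:
    params = {"projection_id": "proj-123", "last_file": last, "page_size": limit}  # illustrative ids
    page = requests.post(route, json=params).json()
    pages.append(page)
    # Assumed response shape: the endpoint reports the cursor for the next request
    # and omits it (or returns null) when nothing is left to fetch.
    last = page.get("last_file")
    if last is None:
        break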
@@ -554,7 +560,6 @@ def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]:
 
         raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.")
 
-
     def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool:
         '''
         Deprecated in favor of `map.embeddings.latent`.
@@ -570,7 +575,6 @@ def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool:
         '''
         raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.")
 
-
     def __repr__(self) -> str:
         return str(self.df)
 
@@ -590,7 +594,7 @@ def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] =
         self.auto_cleanup = auto_cleanup
 
     @property
-    def df(self, overwrite: Optional[bool]=False) -> pd.DataFrame:
+    def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame:
         '''
         Pandas DataFrame mapping each data point to its tags.
         '''
@@ -623,7 +627,7 @@ def df(self, overwrite: Optional[bool]=False) -> pd.DataFrame:
             tb = tb.append_column(tag["tag_name"], bitmask)
             tbs.append(tb)
         return pa.concat_tables(tbs).to_pandas()
 
     def get_tags(self) -> Dict[str, List[str]]:
         '''
         Retrieves back all tags made in the web browser for a specific map.
@@ -632,23 +636,26 @@
 
         Returns:
             A list of tags a user has created for projection.
         '''
-        tags = requests.get(self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all',
-                            headers=self.dataset.header,
-                            params={'project_id': self.dataset.id,
-                                    'projection_id': self.projection.id,
-                                    'include_dsl_rule': False}).json()
+        tags = requests.get(
+            self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all',
+            headers=self.dataset.header,
+            params={'project_id': self.dataset.id, 'projection_id': self.projection.id, 'include_dsl_rule': False},
+        ).json()
         keep_tags = []
         for tag in tags:
-            is_complete = requests.get(self.dataset.atlas_api_path + '/v1/project/projection/tags/status',
+            is_complete = requests.get(
+                self.dataset.atlas_api_path + '/v1/project/projection/tags/status',
                 headers=self.dataset.header,
-                params={'project_id': self.dataset.id,
-                        'tag_id': tag["tag_id"],
-                        }).json()['is_complete']
+                params={
+                    'project_id': self.dataset.id,
+                    'tag_id': tag["tag_id"],
+                },
+            ).json()['is_complete']
             if is_complete:
                 keep_tags.append(tag)
         return keep_tags
-    def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool]=False):
+
+    def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False):
         '''
         Returns the datum ids in a given tag.
@@ -687,7 +694,7 @@ def _get_tag_by_name(self, name: str) -> Dict:
             if tag["tag_name"] == name:
                 return tag
         raise ValueError(f"Tag {name} not found in projection {self.projection.id}.")
 
     def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False):
         """
         Downloads the feather tree for large sidecar columns.
@@ -715,12 +722,12 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False):
                 download_success = True
             except pa.ArrowInvalid:
                 path.unlink(missing_ok=True)
 
             if not download_success:
                 raise Exception(f"Failed to download tag {tag_name}.")
             ordered_tag_paths.append(path)
         return ordered_tag_paths
 
     def _remove_outdated_tag_files(self, tag_definition_ids: List[str]):
         '''
         Attempts to remove outdated tag files based on tag definition ids.
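The _download_tag hunk preserves a download-and-validate pattern: write a feather file, and if pyarrow cannot read it back, delete it and treat the attempt as failed. A minimal sketch of that pattern under stated assumptions (the fetch callable and retry count are illustrative, not the library's API):

from pathlib import Path
from typing import Callable

import pyarrow as pa
import pyarrow.feather as feather

def download_validated(fetch: Callable[[Path], None], path: Path, attempts: int = 3) -> bool:
    # `fetch` is any callable that writes the downloaded bytes to `path`.
    for _ in range(attempts):
        fetch(path)
        try:
            feather.read_table(path)  # a truncated or corrupt download raises ArrowInvalid
            return True
        except pa.ArrowInvalid:
            path.unlink(missing_ok=True)  # discard the corrupt file and retry
    return False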
