diff --git a/nomic/atlas.py b/nomic/atlas.py index 6b2632d9..771d0e8d 100644 --- a/nomic/atlas.py +++ b/nomic/atlas.py @@ -13,7 +13,7 @@ from tqdm import tqdm from .data_inference import NomicDuplicatesOptions, NomicEmbedOptions, NomicProjectOptions, NomicTopicOptions -from .dataset import AtlasDataStream, AtlasDataset +from .dataset import AtlasDataset, AtlasDataStream from .settings import * from .utils import arrow_iterator, b64int, get_random_name @@ -61,7 +61,7 @@ def map_data( project_name = get_random_name() dataset_name = project_name - index_name=dataset_name + index_name = dataset_name if identifier: dataset_name = identifier diff --git a/nomic/aws/sagemaker.py b/nomic/aws/sagemaker.py index 9b3b958d..6b4df253 100644 --- a/nomic/aws/sagemaker.py +++ b/nomic/aws/sagemaker.py @@ -15,9 +15,7 @@ def _get_sagemaker_role(): try: return sagemaker.get_execution_role() except ValueError: - raise ValueError( - "Unable to fetch sagemaker execution role. Please provide a role." - ) + raise ValueError("Unable to fetch sagemaker execution role. Please provide a role.") def parse_sagemaker_response(response): @@ -157,9 +155,7 @@ def embed_texts( for i in tqdm(range(0, len(texts), batch_size)): batch = json.dumps({"texts": texts[i : i + batch_size]}) - response = client.invoke_endpoint( - EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json" - ) + response = client.invoke_endpoint(EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json") embeddings.extend(parse_sagemaker_response(response)) return { diff --git a/nomic/cli.py b/nomic/cli.py index 31392334..5d1300c8 100644 --- a/nomic/cli.py +++ b/nomic/cli.py @@ -1,10 +1,10 @@ import json -import jwt import os import time from pathlib import Path import click +import jwt import requests from rich.console import Console @@ -63,7 +63,6 @@ def login(token, tenant='production', domain=None): if not nomic_base_path.exists(): nomic_base_path.mkdir() - expires = None refresh_token = None @@ -85,7 +84,7 @@ def login(token, tenant='production', domain=None): 'refresh_token': refresh_token, 'token': bearer_token, 'tenant': tenant, - 'expires': expires + 'expires': expires, } if tenant == 'enterprise': diff --git a/nomic/data_inference.py b/nomic/data_inference.py index 3f2a47ac..80a689de 100644 --- a/nomic/data_inference.py +++ b/nomic/data_inference.py @@ -3,9 +3,7 @@ import pyarrow as pa from pydantic import BaseModel, Field -from .settings import ( - DEFAULT_DUPLICATE_THRESHOLD, -) +from .settings import DEFAULT_DUPLICATE_THRESHOLD def from_list(values: Dict[str, Any], schema=None) -> pa.Table: diff --git a/nomic/data_operations.py b/nomic/data_operations.py index 15680356..75ac4ffa 100644 --- a/nomic/data_operations.py +++ b/nomic/data_operations.py @@ -9,7 +9,7 @@ from datetime import datetime from io import BytesIO from pathlib import Path -from typing import Dict, Iterable, Optional, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import numpy as np import pandas @@ -36,12 +36,16 @@ def __init__(self, projection: "AtlasProjection"): self.projection = projection self.id_field = self.projection.dataset.id_field try: - duplicate_fields = [field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field] + duplicate_fields = [ + field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field + ] cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field] assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map." self.duplicate_field = duplicate_fields[0] self.cluster_field = cluster_fields[0] - self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, self.duplicate_field, self.cluster_field]) + self._tb: pa.Table = projection._fetch_tiles().select( + [self.id_field, self.duplicate_field, self.cluster_field] + ) except pa.lib.ArrowInvalid as e: raise ValueError("Duplicate detection has not yet been run on this map.") self.duplicate_field = self.duplicate_field.lstrip("_") @@ -75,7 +79,9 @@ def deletion_candidates(self) -> List[str]: def __repr__(self) -> str: repr = f"===Atlas Duplicates for ({self.projection})\n" - duplicate_count = len(self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate'))) + duplicate_count = len( + self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')) + ) cluster_count = len(self.tb[self.cluster_field].value_counts()) repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n" return repr + self.df.__repr__() @@ -453,7 +459,7 @@ def _download_latent(self): route = self.projection.dataset.atlas_api_path + '/v1/project/data/get/embedding/paged' last = None - with tqdm(total=self.dataset.total_datums//limit) as pbar: + with tqdm(total=self.dataset.total_datums // limit) as pbar: while True: params = {'projection_id': self.projection.id, "last_file": last, "page_size": limit} r = requests.post(route, headers=self.projection.dataset.header, json=params) @@ -554,7 +560,6 @@ def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]: raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.") - def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool: ''' Deprecated in favor of `map.embeddings.latent`. @@ -570,7 +575,6 @@ def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bo ''' raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.") - def __repr__(self) -> str: return str(self.df) @@ -590,7 +594,7 @@ def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = self.auto_cleanup = auto_cleanup @property - def df(self, overwrite: Optional[bool]=False) -> pd.DataFrame: + def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: ''' Pandas DataFrame mapping each data point to its tags. ''' @@ -623,7 +627,7 @@ def df(self, overwrite: Optional[bool]=False) -> pd.DataFrame: tb = tb.append_column(tag["tag_name"], bitmask) tbs.append(tb) return pa.concat_tables(tbs).to_pandas() - + def get_tags(self) -> Dict[str, List[str]]: ''' Retrieves back all tags made in the web browser for a specific map. @@ -632,23 +636,26 @@ def get_tags(self) -> Dict[str, List[str]]: Returns: A list of tags a user has created for projection. ''' - tags = requests.get(self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all', - headers=self.dataset.header, - params={'project_id': self.dataset.id, - 'projection_id': self.projection.id, - 'include_dsl_rule': False}).json() + tags = requests.get( + self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all', + headers=self.dataset.header, + params={'project_id': self.dataset.id, 'projection_id': self.projection.id, 'include_dsl_rule': False}, + ).json() keep_tags = [] for tag in tags: - is_complete = requests.get(self.dataset.atlas_api_path + '/v1/project/projection/tags/status', + is_complete = requests.get( + self.dataset.atlas_api_path + '/v1/project/projection/tags/status', headers=self.dataset.header, - params={'project_id': self.dataset.id, - 'tag_id': tag["tag_id"], - }).json()['is_complete'] + params={ + 'project_id': self.dataset.id, + 'tag_id': tag["tag_id"], + }, + ).json()['is_complete'] if is_complete: keep_tags.append(tag) return keep_tags - - def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool]=False): + + def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): ''' Returns the datum ids in a given tag. @@ -687,7 +694,7 @@ def _get_tag_by_name(self, name: str) -> Dict: if tag["tag_name"] == name: return tag raise ValueError(f"Tag {name} not found in projection {self.projection.id}.") - + def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): """ Downloads the feather tree for large sidecar columns. @@ -715,12 +722,12 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): download_success = True except pa.ArrowInvalid: path.unlink(missing_ok=True) - + if not download_success: raise Exception(f"Failed to download tag {tag_name}.") ordered_tag_paths.append(path) return ordered_tag_paths - + def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): ''' Attempts to remove outdated tag files based on tag definition ids. diff --git a/nomic/dataset.py b/nomic/dataset.py index 4b57d998..8eb3e890 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -34,13 +34,7 @@ NomicTopicOptions, convert_pyarrow_schema_for_atlas, ) -from .data_operations import ( - AtlasMapData, - AtlasMapDuplicates, - AtlasMapEmbeddings, - AtlasMapTags, - AtlasMapTopics, -) +from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics from .settings import * from .utils import assert_valid_project_id, get_object_size_in_bytes @@ -113,9 +107,7 @@ def _get_current_user(self): ) response = validate_api_http_response(response) if not response.status_code == 200: - raise ValueError( - "Your authorization token is no longer valid. Run `nomic login` to obtain a new one." - ) + raise ValueError("Your authorization token is no longer valid. Run `nomic login` to obtain a new one.") return response.json() @@ -130,9 +122,7 @@ def _validate_map_data_inputs(self, colorable_fields, id_field, data_sample): for field in colorable_fields: if field not in data_sample: - raise Exception( - f"Cannot color by field `{field}` as it is not present in the metadata." - ) + raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.") def _get_current_users_main_organization(self): """ @@ -149,10 +139,7 @@ def _get_current_users_main_organization(self): return organization for organization in user["organizations"]: - if ( - organization["user_id"] == user["sub"] - and organization["access_role"] == "OWNER" - ): + if organization["user_id"] == user["sub"] and organization["access_role"] == "OWNER": return organization def _delete_project_by_id(self, project_id): @@ -180,9 +167,7 @@ def _get_project_by_id(self, project_id: str): ) if response.status_code != 200: - raise Exception( - f"Could not access dataset with id {project_id}: {response.text}" - ) + raise Exception(f"Could not access dataset with id {project_id}: {response.text}") return response.json() @@ -268,9 +253,7 @@ def _get_index_job(self, job_id: str): return response.json() - def _validate_and_correct_arrow_upload( - self, data: pa.Table, project: "AtlasDataset" - ) -> pa.Table: + def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasDataset") -> pa.Table: """ Private method. validates upload data against the dataset arrow schema, and associated other checks. @@ -302,8 +285,7 @@ def _validate_and_correct_arrow_upload( for col in data.column_names: if col.lower() in seen: raise ValueError( - f"Two different fields have the same lowercased name, `{col}`" - ": you must use unique column names." + f"Two different fields have the same lowercased name, `{col}`" ": you must use unique column names." ) seen.add(col.lower()) @@ -333,41 +315,27 @@ def _validate_and_correct_arrow_upload( logger.warning( f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version." ) - reformatted[field.name] = pc.fill_null( - reformatted[field.name], "null" - ) + reformatted[field.name] = pc.fill_null(reformatted[field.name], "null") if pc.any(pc.equal(pc.binary_length(reformatted[field.name]), 0)): - mask = pc.equal( - pc.binary_length(reformatted[field.name]), 0 - ).combine_chunks() + mask = pc.equal(pc.binary_length(reformatted[field.name]), 0).combine_chunks() assert pa.types.is_boolean(mask.type) - reformatted[field.name] = pc.replace_with_mask( - reformatted[field.name], mask, "null" - ) + reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null") for field in data.schema: if not field.name in reformatted: if field.name == "_embeddings": reformatted["_embeddings"] = data["_embeddings"] else: - logger.warning( - f"Field {field.name} present in data, but not found in table schema. Ignoring." - ) + logger.warning(f"Field {field.name} present in data, but not found in table schema. Ignoring.") data = pa.Table.from_pydict(reformatted, schema=project.schema) if project.meta["insert_update_delete_lock"]: - raise Exception( - "Project is currently indexing and cannot ingest new datums. Try again later." - ) + raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.") # The following two conditions should never occur given the above, but just in case... - assert ( - project.id_field in data.column_names - ), f"Upload does not contain your specified id_field" + assert project.id_field in data.column_names, f"Upload does not contain your specified id_field" if not pa.types.is_string(data[project.id_field].type): - logger.warning( - f"id_field is not a string. Converting to string from {data[project.id_field].type}" - ) + logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}") data = data.drop([project.id_field]).append_column( project.id_field, data[project.id_field].cast(pa.string()) ) @@ -378,17 +346,15 @@ def _validate_and_correct_arrow_upload( continue raise ValueError("Metadata fields cannot start with _") if pc.max(pc.utf8_length(data[project.id_field])).as_py() > 36: - first_match = data.filter( - pc.greater(pc.utf8_length(data[project.id_field]), 36) - ).to_pylist()[0][project.id_field] + first_match = data.filter(pc.greater(pc.utf8_length(data[project.id_field]), 36)).to_pylist()[0][ + project.id_field + ] raise ValueError( f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters." ) return data - def _get_organization( - self, organization_slug=None, organization_id=None - ) -> Tuple[str, str]: + def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]: """ Gets an organization by either its name or id. @@ -402,16 +368,12 @@ def _get_organization( """ if organization_slug is None: - if ( - organization_id is None - ): # default to current users organization (the one with their name) + if organization_id is None: # default to current users organization (the one with their name) organization = self._get_current_users_main_organization() organization_slug = organization["slug"] organization_id = organization["organization_id"] else: - raise NotImplementedError( - "Getting organization by a specific ID is not yet implemented." - ) + raise NotImplementedError("Getting organization by a specific ID is not yet implemented.") else: try: @@ -434,9 +396,7 @@ class AtlasIndex: the points in the index that you can browse online. """ - def __init__( - self, atlas_index_id, name, indexed_field, projections, dataset: AtlasDataset - ): + def __init__(self, atlas_index_id, name, indexed_field, projections, dataset: AtlasDataset): """Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.""" self.id = atlas_index_id self.name = name @@ -504,8 +464,7 @@ def map_link(self): @property def _status(self): response = requests.get( - self.dataset.atlas_api_path - + f"/v1/project/index/job/progress/{self.atlas_index_id}", + self.dataset.atlas_api_path + f"/v1/project/index/job/progress/{self.atlas_index_id}", headers=self.dataset.header, ) if response.status_code != 200: @@ -582,9 +541,7 @@ def _repr_html_(self): def duplicates(self): """Duplicate detection state""" if self.dataset.is_locked: - raise Exception( - "Dataset is locked! Please wait until the dataset is unlocked to access duplicates." - ) + raise Exception("Dataset is locked! Please wait until the dataset is unlocked to access duplicates.") if self._duplicates is None: self._duplicates = AtlasMapDuplicates(self) return self._duplicates @@ -642,8 +599,7 @@ def schema(self): ) if self._schema is None: response = requests.get( - self.dataset.atlas_api_path - + f"/v1/project/projection/{self.projection_id}/schema", + self.dataset.atlas_api_path + f"/v1/project/projection/{self.projection_id}/schema", headers=self.dataset.header, ) if response.status_code != 200: @@ -676,24 +632,16 @@ def _fetch_tiles(self, overwrite: bool = True): return self._tile_data self._download_large_feather(overwrite=overwrite) tbs = [] - root = feather.read_table( - self.tile_destination / "0/0/0.feather", memory_map=True - ) + root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True) try: - sidecars = set( - [v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()] - ) + sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()]) except KeyError: sidecars = set([]) - sidecars |= set( - sidecar_name for (_, sidecar_name) in self._registered_sidecars() - ) + sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) for path in self._tiles_in_order(): tb = pa.feather.read_table(path, memory_map=True) for sidecar_file in sidecars: - carfile = pa.feather.read_table( - path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True - ) + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) @@ -722,9 +670,7 @@ def children(z, x, y): # Pop off the front, extend the back (breadth first traversal) while len(paths) > 0: z, x, y = paths.pop(0) - path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix( - ".feather" - ) + path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather") if path.exists(): if coords_only: yield (z, x, y) @@ -736,9 +682,7 @@ def children(z, x, y): def tile_destination(self): return Path("~/.nomic/cache", self.id).expanduser() - def _download_large_feather( - self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True - ): + def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True): """ Downloads the feather tree. Args: @@ -753,9 +697,7 @@ def _download_large_feather( root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/" all_quads = [] sidecars = None - registered_sidecars = set( - sidecar_name for (_, sidecar_name) in self._registered_sidecars() - ) + registered_sidecars = set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) while len(quads) > 0: rawquad = quads.pop(0) quad = rawquad + ".feather" @@ -785,9 +727,7 @@ def _download_large_feather( if sidecars is None and b"sidecars" in schema.metadata: # Grab just the filenames - sidecars = set( - [v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()] - ) + sidecars = set([v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()]) elif sidecars is None: sidecars = set() if not "." in rawquad: @@ -841,9 +781,7 @@ class AtlasDataStream(AtlasClass): def __init__(self, name: Optional[str] = "contrastors"): super().__init__() if name != "contrastors": - raise NotImplementedError( - "Only contrastors datastream is currently supported" - ) + raise NotImplementedError("Only contrastors datastream is currently supported") self.name = name # TODO: add support for other datastreams @@ -882,9 +820,7 @@ def __init__( * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view. * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists. """ - assert ( - identifier is not None or dataset_id is not None - ), "You must pass a dataset identifier" + assert identifier is not None or dataset_id is not None, "You must pass a dataset identifier" super().__init__() @@ -908,9 +844,7 @@ def __init__( dataset_id = dataset["id"] if dataset_id is None: # if there is no existing project, make a new one. - if ( - unique_id_field is None - ): # if not all parameters are specified, we weren't trying to make a project + if unique_id_field is None: # if not all parameters are specified, we weren't trying to make a project raise ValueError(f"Dataset `{identifier}` does not exist.") # if modality is None: @@ -936,9 +870,7 @@ def delete(self): organization = self._get_current_users_main_organization() organization_slug = organization["slug"] - logger.info( - f"Deleting dataset `{self.slug}` from organization `{organization_slug}`" - ) + logger.info(f"Deleting dataset `{self.slug}` from organization `{organization_slug}`") self._delete_project_by_id(project_id=self.id) @@ -1059,9 +991,7 @@ def id_field(self) -> str: @property def created_timestamp(self) -> datetime: - return datetime.strptime( - self.meta["created_timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z" - ) + return datetime.strptime(self.meta["created_timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z") @property def total_datums(self) -> int: @@ -1105,9 +1035,7 @@ def schema(self) -> Optional[pa.Schema]: if self._schema is not None: return self._schema if "schema" in self.meta and self.meta["schema"] is not None: - self._schema: pa.Schema = ipc.read_schema( - io.BytesIO(base64.b64decode(self.meta["schema"])) - ) + self._schema: pa.Schema = ipc.read_schema(io.BytesIO(base64.b64decode(self.meta["schema"]))) return self._schema return None @@ -1134,9 +1062,7 @@ def wait_for_dataset_lock(self): has_logged = True time.sleep(5) - def get_map( - self, name: str = None, atlas_index_id: str = None, projection_id: str = None - ) -> AtlasProjection: + def get_map(self, name: str = None, atlas_index_id: str = None, projection_id: str = None) -> AtlasProjection: """ Retrieves a map. @@ -1155,22 +1081,16 @@ def get_map( for index in indices: if index.id == atlas_index_id: if len(index.projections) == 0: - raise ValueError( - f"No map found under index with atlas_index_id='{atlas_index_id}'" - ) + raise ValueError(f"No map found under index with atlas_index_id='{atlas_index_id}'") return index.projections[0] - raise ValueError( - f"Could not find a map with atlas_index_id='{atlas_index_id}'" - ) + raise ValueError(f"Could not find a map with atlas_index_id='{atlas_index_id}'") if projection_id is not None: for index in indices: for projection in index.projections: if projection.id == projection_id: return projection - raise ValueError( - f"Could not find a map with projection_id='{atlas_index_id}'" - ) + raise ValueError(f"Could not find a map with projection_id='{atlas_index_id}'") if len(indices) == 0: raise ValueError("You have no maps built in your project") @@ -1253,9 +1173,7 @@ def create_index( colorable_fields = [] for field in self.dataset_fields: - if field not in [self.id_field, indexed_field] and not field.startswith( - "_" - ): + if field not in [self.id_field, indexed_field] and not field.startswith("_"): colorable_fields.append(field) if self.modality == "embedding": @@ -1272,9 +1190,7 @@ def create_index( "colorable_fields": colorable_fields, "model_hyperparameters": None, "nearest_neighbor_index": "HNSWIndex", - "nearest_neighbor_index_hyperparameters": json.dumps( - {"space": "l2", "ef_construction": 100, "M": 16} - ), + "nearest_neighbor_index_hyperparameters": json.dumps({"space": "l2", "ef_construction": 100, "M": 16}), "projection": "NomicProject", "projection_hyperparameters": json.dumps( { @@ -1317,14 +1233,10 @@ def create_index( ) if indexed_field is None: - raise Exception( - "You did not specify a field to index. Specify an 'indexed_field'." - ) + raise Exception("You did not specify a field to index. Specify an 'indexed_field'.") if indexed_field not in self.dataset_fields: - raise Exception( - f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}" - ) + raise Exception(f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}") build_template = { "project_id": self.id, @@ -1343,9 +1255,7 @@ def create_index( } ), "nearest_neighbor_index": "HNSWIndex", - "nearest_neighbor_index_hyperparameters": json.dumps( - {"space": "l2", "ef_construction": 100, "M": 16} - ), + "nearest_neighbor_index_hyperparameters": json.dumps({"space": "l2", "ef_construction": 100, "M": 16}), "projection": "NomicProject", "projection_hyperparameters": json.dumps( { @@ -1379,9 +1289,7 @@ def create_index( json=build_template, ) if response.status_code != 200: - logger.info( - "Create dataset failed with code: {}".format(response.status_code) - ) + logger.info("Create dataset failed with code: {}".format(response.status_code)) logger.info("Additional info: {}".format(response.text)) raise Exception(response.json()["detail"]) @@ -1406,9 +1314,7 @@ def create_index( if projection is None: logger.warning("Could not find a map being built for this dataset.") - logger.info( - f"Created map `{projection.name}` in dataset `{self.identifier}`: {projection.map_link}" - ) + logger.info(f"Created map `{projection.name}` in dataset `{self.identifier}`: {projection.map_link}") return projection def __repr__(self): @@ -1442,13 +1348,7 @@ def _repr_html_(self): return html def __str__(self): - return "\n".join( - [ - str(projection) - for index in self.indices - for projection in index.projections - ] - ) + return "\n".join([str(projection) for index in self.indices for projection in index.projections]) def get_data(self, ids: List[str]) -> List[Dict]: """ @@ -1465,9 +1365,7 @@ def get_data(self, ids: List[str]) -> List[Dict]: if not isinstance(ids, list): raise ValueError("You must specify a list of ids when getting data.") if isinstance(ids[0], list): - raise ValueError( - "You must specify a list of ids when getting data, not a nested list." - ) + raise ValueError("You must specify a list of ids when getting data, not a nested list.") response = requests.post( self.atlas_api_path + "/v1/project/data/get", headers=self.header, @@ -1517,9 +1415,7 @@ def add_data( embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints. pbar: (Optional). A tqdm progress bar to update. """ - if embeddings is not None or ( - isinstance(data, pa.Table) and "_embeddings" in data.column_names - ): + if embeddings is not None or (isinstance(data, pa.Table) and "_embeddings" in data.column_names): self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar) else: self._add_text(data=data, pbar=pbar) @@ -1535,9 +1431,7 @@ def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None): elif isinstance(data, list): data = pa.Table.from_pylist(data) elif not isinstance(data, pa.Table): - raise ValueError( - "Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table." - ) + raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.") self._add_data(data, pbar=pbar) def _add_embeddings( @@ -1561,9 +1455,7 @@ def _add_embeddings( """ assert type(embeddings) == np.ndarray, "Embeddings must be a NumPy array." assert len(embeddings.shape) == 2, "Embeddings must be a 2D NumPy array." - assert ( - len(data) == embeddings.shape[0] - ), "Data and embeddings must have the same number of rows." + assert len(data) == embeddings.shape[0], "Data and embeddings must have the same number of rows." assert len(data) > 0, "Data must have at least one row." tb: pa.Table @@ -1589,9 +1481,7 @@ def _add_embeddings( assert not np.isnan(embeddings).any(), "Embeddings must not contain NaN values." assert not np.isinf(embeddings).any(), "Embeddings must not contain Inf values." - pyarrow_embeddings = pa.FixedSizeListArray.from_arrays( - embeddings.reshape((-1)), embeddings.shape[1] - ) + pyarrow_embeddings = pa.FixedSizeListArray.from_arrays(embeddings.reshape((-1)), embeddings.shape[1]) data_with_embeddings = tb.append_column("_embeddings", pyarrow_embeddings) @@ -1642,9 +1532,7 @@ def send_request(i): data_shard = data.slice(i, shard_size) with io.BytesIO() as buffer: data_shard = data_shard.replace_schema_metadata({"project_id": self.id}) - feather.write_feather( - data_shard, buffer, compression="zstd", compression_level=6 - ) + feather.write_feather(data_shard, buffer, compression="zstd", compression_level=6) buffer.seek(0) response = requests.post( @@ -1663,26 +1551,18 @@ def send_request(i): succeeded = 0 errors_504 = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = { - executor.submit(send_request, i): i - for i in range(0, len(data), shard_size) - } + futures = {executor.submit(send_request, i): i for i in range(0, len(data), shard_size)} while futures: # check for status of the futures which are currently working - done, not_done = concurrent.futures.wait( - futures, return_when=concurrent.futures.FIRST_COMPLETED - ) + done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED) # process any completed futures for future in done: response = future.result() if response.status_code != 200: try: logger.error(f"Shard upload failed: {response.text}") - if ( - "more datums exceeds your organization limit" - in response.json() - ): + if "more datums exceeds your organization limit" in response.json(): return False if "Project transaction lock is held" in response.json(): raise Exception( @@ -1693,9 +1573,7 @@ def send_request(i): except (requests.JSONDecodeError, json.decoder.JSONDecodeError): if response.status_code == 413: # Possibly split in two and retry? - logger.error( - "Shard upload failed: you are sending meta-data that is too large." - ) + logger.error("Shard upload failed: you are sending meta-data that is too large.") pbar.update(1) response.close() failed += shard_size @@ -1705,25 +1583,16 @@ def send_request(i): logger.debug( f"{self.identifier}: Connection failed for records {start_point}-{start_point + shard_size}, retrying." ) - failure_fraction = errors_504 / ( - failed + succeeded + errors_504 - ) - if ( - failure_fraction > 0.5 - and errors_504 > shard_size * 3 - ): + failure_fraction = errors_504 / (failed + succeeded + errors_504) + if failure_fraction > 0.5 and errors_504 > shard_size * 3: raise RuntimeError( f"{self.identifier}: Atlas is under high load and cannot ingest datums at this time. Please try again later." ) - new_submission = executor.submit( - send_request, start_point - ) + new_submission = executor.submit(send_request, start_point) futures[new_submission] = start_point response.close() else: - logger.error( - f"{self.identifier}: Shard upload failed: {response}" - ) + logger.error(f"{self.identifier}: Shard upload failed: {response}") failed += shard_size pbar.update(1) response.close() @@ -1775,12 +1644,12 @@ def update_maps( raise ValueError(msg) if embeddings is not None and len(data) != embeddings.shape[0]: - msg = "Expected data and embeddings to be the same length but found lengths {} and {} respectively.".format() + msg = ( + "Expected data and embeddings to be the same length but found lengths {} and {} respectively.".format() + ) raise ValueError(msg) - shard_size = ( - 2000 # TODO someone removed shard size from params and didn't update - ) + shard_size = 2000 # TODO someone removed shard size from params and didn't update # Add new data logger.info("Uploading data to Nomic's neural database Atlas.") with tqdm(total=len(data) // shard_size) as pbar: diff --git a/nomic/embed.py b/nomic/embed.py index d2a9d0be..2c177ae9 100644 --- a/nomic/embed.py +++ b/nomic/embed.py @@ -60,13 +60,21 @@ def request_backoff( return response -def text_api_request(texts: List[str], model: str, task_type: str, dimensionality: int = None, long_text_mode: str = "truncate"): +def text_api_request( + texts: List[str], model: str, task_type: str, dimensionality: int = None, long_text_mode: str = "truncate" +): global atlas_class response = request_backoff( lambda: requests.post( atlas_class.atlas_api_path + "/v1/embedding/text", headers=atlas_class.header, - json={"texts": texts, "model": model, "task_type": task_type, "dimensionality": dimensionality, "long_text_mode": long_text_mode}, + json={ + "texts": texts, + "model": model, + "task_type": task_type, + "dimensionality": dimensionality, + "long_text_mode": long_text_mode, + }, ) ) @@ -137,7 +145,7 @@ def text( device: The device to use for local embeddings. Defaults to CPU, or Metal on Apple Silicon. It can be set to: - "gpu": Use the best available GPU. - "amd", "nvidia": Use the best available GPU from the specified vendor. - - A specific device name from the output of `GPT4All.list_gpus()` + - A specific device name from the output of `GPT4All.list_gpus()` kwargs: Remaining arguments are passed to the Embed4All contructor. Returns: @@ -192,10 +200,17 @@ def _text_atlas( long_text_mode: str, ) -> dict[str, Any]: global atlas_class - assert task_type in ["search_query", "search_document", "classification", "clustering"], f"Invalid task type: {task_type}" + assert task_type in [ + "search_query", + "search_document", + "classification", + "clustering", + ], f"Invalid task type: {task_type}" if dimensionality and dimensionality < MIN_EMBEDDING_DIMENSIONALITY: - logging.warning(f"Dimensionality {dimensionality} is less than the suggested of {MIN_EMBEDDING_DIMENSIONALITY}. Performance may be degraded.") + logging.warning( + f"Dimensionality {dimensionality} is less than the suggested of {MIN_EMBEDDING_DIMENSIONALITY}. Performance may be degraded." + ) if atlas_class is None: atlas_class = AtlasClass() @@ -295,16 +310,16 @@ def image_api_request(images: List[Tuple[str, bytes]], model: str = 'nomic-embed return response.json() else: raise Exception((response.status_code, response.text)) - - + + def resize_pil(img): width, height = img.size - #if image is too large, downsample before sending over the wire + # if image is too large, downsample before sending over the wire max_width = 512 max_height = 512 if max_width > 512 or max_height > 512: - downsize_factor = max(width/max_width, height/max_height) - img.resize((width/downsize_factor, height/downsize_factor)) + downsize_factor = max(width / max_width, height / max_height) + img.resize((width / downsize_factor, height / downsize_factor)) return img @@ -333,10 +348,10 @@ def images(images: Iterable[Union[str, PIL.Image.Image]], model: str = 'nomic-em image_batch = [] for image in images: if isinstance(image, str) and os.path.exists(image): - img = resize_pil(PIL.Image.open(image)) - buffered = BytesIO() - img.save(buffered, format="JPEG") - image_batch.append(("images", buffered.getvalue())) + img = resize_pil(PIL.Image.open(image)) + buffered = BytesIO() + img.save(buffered, format="JPEG") + image_batch.append(("images", buffered.getvalue())) elif isinstance(image, PIL.Image.Image): img = resize_pil(image) @@ -346,7 +361,6 @@ def images(images: Iterable[Union[str, PIL.Image.Image]], model: str = 'nomic-em else: raise ValueError(f"Not a valid file: {image}") - combined = {'embeddings': [], 'usage': {}, 'model': model} with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] diff --git a/nomic/utils.py b/nomic/utils.py index 2a01ba99..7a4acdc3 100644 --- a/nomic/utils.py +++ b/nomic/utils.py @@ -4,12 +4,10 @@ import sys from io import BytesIO from typing import Optional - -import requests -import pyarrow as pa from uuid import UUID import pyarrow as pa +import requests nouns = [ 'newton', @@ -239,6 +237,7 @@ def get_object_size_in_bytes(obj): return sz + # Helpful function for downloading feather files # Best for small feather files def download_feather(url: str, path: str, headers: Optional[dict] = None):