From 1c112ee6c6a393bd751dfd68ef10397f6cc75219 Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Fri, 19 Apr 2024 17:58:28 -0400 Subject: [PATCH 1/5] index deletion --- nomic/dataset.py | 795 ++++++++++++++++++++++++++++------------------- 1 file changed, 481 insertions(+), 314 deletions(-) diff --git a/nomic/dataset.py b/nomic/dataset.py index e4562050..6bbd9a01 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -34,7 +34,13 @@ NomicTopicOptions, convert_pyarrow_schema_for_atlas, ) -from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics +from .data_operations import ( + AtlasMapData, + AtlasMapDuplicates, + AtlasMapEmbeddings, + AtlasMapTags, + AtlasMapTopics, +) from .settings import * from .utils import assert_valid_project_id, get_object_size_in_bytes @@ -46,19 +52,19 @@ def __init__(self): class AtlasClass(object): def __init__(self): - ''' + """ Initializes the Atlas client. - ''' - - if self.credentials['tenant'] == 'staging': - api_hostname = 'staging-api-atlas.nomic.ai' - web_hostname = 'staging-atlas.nomic.ai' - elif self.credentials['tenant'] == 'production': - api_hostname = 'api-atlas.nomic.ai' - web_hostname = 'atlas.nomic.ai' - elif self.credentials['tenant'] == 'enterprise': - api_hostname = self.credentials['api_domain'] - web_hostname = self.credentials['frontend_domain'] + """ + + if self.credentials["tenant"] == "staging": + api_hostname = "staging-api-atlas.nomic.ai" + web_hostname = "staging-atlas.nomic.ai" + elif self.credentials["tenant"] == "production": + api_hostname = "api-atlas.nomic.ai" + web_hostname = "atlas.nomic.ai" + elif self.credentials["tenant"] == "enterprise": + api_hostname = self.credentials["api_domain"] + web_hostname = self.credentials["frontend_domain"] else: raise ValueError("Invalid tenant.") @@ -66,14 +72,14 @@ def __init__(self): self.web_path = f"https://{web_hostname}" try: - override_api_path = os.environ['ATLAS_API_PATH'] + override_api_path = os.environ["ATLAS_API_PATH"] except KeyError: override_api_path = None if override_api_path: self.atlas_api_path = override_api_path - token = self.credentials['token'] + token = self.credentials["token"] self.token = token self.header = {"Authorization": f"Bearer {token}"} @@ -98,7 +104,7 @@ def credentials(self): def _get_current_user(self): api_base_path = self.atlas_api_path - if self.atlas_api_path.startswith('https://api-atlas.nomic.ai'): + if self.atlas_api_path.startswith("https://api-atlas.nomic.ai"): api_base_path = "https://no-cdn-api-atlas.nomic.ai" response = requests.get( @@ -107,39 +113,46 @@ def _get_current_user(self): ) response = validate_api_http_response(response) if not response.status_code == 200: - raise ValueError("Your authorization token is no longer valid. Run `nomic login` to obtain a new one.") + raise ValueError( + "Your authorization token is no longer valid. Run `nomic login` to obtain a new one." 
+ ) return response.json() def _validate_map_data_inputs(self, colorable_fields, id_field, data_sample): - '''Validates inputs to map data calls.''' + """Validates inputs to map data calls.""" if not isinstance(colorable_fields, list): raise ValueError("colorable_fields must be a list of fields") if id_field in colorable_fields: - raise Exception(f'Cannot color by unique id field: {id_field}') + raise Exception(f"Cannot color by unique id field: {id_field}") for field in colorable_fields: if field not in data_sample: - raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.") + raise Exception( + f"Cannot color by field `{field}` as it is not present in the metadata." + ) def _get_current_users_main_organization(self): - ''' + """ Retrieves the ID of the current users default organization. **Returns:** The ID of the current users default organization - ''' + """ user = self._get_current_user() - if user['default_organization']: - for organization in user['organizations']: - if organization['organization_id'] == user['default_organization']: + if user["default_organization"]: + for organization in user["organizations"]: + if organization["organization_id"] == user["default_organization"]: return organization - for organization in user['organizations']: - if organization['user_id'] == user['sub'] and organization['access_role'] == 'OWNER': + for organization in user["organizations"]: + if ( + organization["user_id"] == user["sub"] + and organization["access_role"] == "OWNER" + ): return organization return {} @@ -148,18 +161,18 @@ def _delete_project_by_id(self, project_id): response = requests.post( self.atlas_api_path + "/v1/project/remove", headers=self.header, - json={'project_id': project_id}, + json={"project_id": project_id}, ) def _get_project_by_id(self, project_id: str): - ''' + """ Args: project_id: The project id Returns: Returns the requested dataset. - ''' + """ assert_valid_project_id(project_id) @@ -169,22 +182,24 @@ def _get_project_by_id(self, project_id: str): ) if response.status_code != 200: - raise Exception(f"Could not access dataset with id {project_id}: {response.text}") + raise Exception( + f"Could not access dataset with id {project_id}: {response.text}" + ) return response.json() def _get_organization_by_slug(self, slug: str): - ''' + """ Args: slug: The organization slug Returns: An organization id - ''' + """ - if '/' in slug: - slug = slug.split('/')[0] + if "/" in slug: + slug = slug.split("/")[0] response = requests.get( self.atlas_api_path + f"/v1/organization/{slug}", @@ -193,23 +208,23 @@ def _get_organization_by_slug(self, slug: str): if response.status_code != 200: raise Exception(f"Organization not found: {slug}") - return response.json()['id'] + return response.json()["id"] def _get_dataset_by_slug_identifier(self, identifier: str): - ''' + """ Args: identifier: the organization slug and dataset slug seperated by a slash Returns: Returns the requested dataset. 
- ''' + """ if not self.is_valid_dataset_identifier(identifier=identifier): raise Exception("Invalid dataset identifier") - organization_slug = identifier.split('/')[0] - project_slug = identifier.split('/')[1] + organization_slug = identifier.split("/")[0] + project_slug = identifier.split("/")[1] response = requests.get( self.atlas_api_path + f"/v1/project/{organization_slug}/{project_slug}", headers=self.header, @@ -221,7 +236,7 @@ def _get_dataset_by_slug_identifier(self, identifier: str): return response.json() def is_valid_dataset_identifier(self, identifier: str): - ''' + """ Checks if a string is a valid identifier for a dataset Args: @@ -229,21 +244,21 @@ def is_valid_dataset_identifier(self, identifier: str): Returns: Returns the requested dataset. - ''' - slugs = identifier.split('/') - if '/' not in identifier or len(slugs) != 2: + """ + slugs = identifier.split("/") + if "/" not in identifier or len(slugs) != 2: return False return True def _get_index_job(self, job_id: str): - ''' + """ Args: job_id: The job id to retrieve the state of. Returns: Job ID meta-data. - ''' + """ response = requests.get( self.atlas_api_path + f"/v1/project/index/job/{job_id}", @@ -251,12 +266,14 @@ def _get_index_job(self, job_id: str): ) if response.status_code != 200: - raise Exception(f'Could not access job state: {response.text}') + raise Exception(f"Could not access job state: {response.text}") return response.json() - def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasDataset") -> pa.Table: - ''' + def _validate_and_correct_arrow_upload( + self, data: pa.Table, project: "AtlasDataset" + ) -> pa.Table: + """ Private method. validates upload data against the dataset arrow schema, and associated other checks. 1. If unique_id_field is specified, validates that each datum has that field. If not, adds it and then notifies the user that it was added. @@ -267,27 +284,28 @@ def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasData Returns: - ''' + """ if not isinstance(data, pa.Table): raise Exception("Invalid data type for upload: {}".format(type(data))) - if project.meta['modality'] == 'text': + if project.meta["modality"] == "text": if "_embeddings" in data: msg = "Can't add embeddings to a text project." raise ValueError(msg) - if project.meta['modality'] == 'embedding': + if project.meta["modality"] == "embedding": if "_embeddings" not in data.column_names: msg = "Must include embeddings in embedding dataset upload." raise ValueError(msg) if project.id_field not in data.column_names: - raise ValueError(f'Data must contain the ID column `{project.id_field}`') + raise ValueError(f"Data must contain the ID column `{project.id_field}`") seen = set() for col in data.column_names: if col.lower() in seen: raise ValueError( - f'Two different fields have the same lowercased name, `{col}`' ': you must use unique column names.' + f"Two different fields have the same lowercased name, `{col}`" + ": you must use unique column names." ) seen.add(col.lower()) @@ -319,47 +337,63 @@ def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasData logger.warning( f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version." 
) - reformatted[field.name] = pc.fill_null(reformatted[field.name], "null") - if pa.compute.any(pa.compute.equal(pa.compute.binary_length(reformatted[field.name]), 0)): # type: ignore - mask = pa.compute.equal(pa.compute.binary_length(reformatted[field.name]), 0).combine_chunks() # type: ignore - assert pa.types.is_boolean(mask.type) # type: ignore - reformatted[field.name] = pa.compute.replace_with_mask(reformatted[field.name], mask, "null") # type: ignore + reformatted[field.name] = pc.fill_null( + reformatted[field.name], "null" + ) + if pc.any(pc.equal(pc.binary_length(reformatted[field.name]), 0)): + mask = pc.equal( + pc.binary_length(reformatted[field.name]), 0 + ).combine_chunks() + assert pa.types.is_boolean(mask.type) + reformatted[field.name] = pc.replace_with_mask( + reformatted[field.name], mask, "null" + ) for field in data.schema: if not field.name in reformatted: if field.name == "_embeddings": - reformatted['_embeddings'] = data['_embeddings'] + reformatted["_embeddings"] = data["_embeddings"] else: - logger.warning(f"Field {field.name} present in data, but not found in table schema. Ignoring.") + logger.warning( + f"Field {field.name} present in data, but not found in table schema. Ignoring." + ) data = pa.Table.from_pydict(reformatted, schema=project.schema) - if project.meta['insert_update_delete_lock']: - raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.") + if project.meta["insert_update_delete_lock"]: + raise Exception( + "Project is currently indexing and cannot ingest new datums. Try again later." + ) # The following two conditions should never occur given the above, but just in case... - assert project.id_field in data.column_names, f"Upload does not contain your specified id_field" + assert ( + project.id_field in data.column_names + ), f"Upload does not contain your specified id_field" if not pa.types.is_string(data[project.id_field].type): - logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}") + logger.warning( + f"id_field is not a string. Converting to string from {data[project.id_field].type}" + ) data = data.drop([project.id_field]).append_column( project.id_field, data[project.id_field].cast(pa.string()) ) for key in data.column_names: - if key.startswith('_'): - if key == '_embeddings': + if key.startswith("_"): + if key == "_embeddings": continue - raise ValueError('Metadata fields cannot start with _') - if pa.compute.max(pa.compute.utf8_length(data[project.id_field])).as_py() > 36: # type: ignore + raise ValueError("Metadata fields cannot start with _") + if pc.max(pc.utf8_length(data[project.id_field])).as_py() > 36: first_match = data.filter( - pa.compute.greater(pa.compute.utf8_length(data[project.id_field]), 36) # type: ignore + pc.greater(pc.utf8_length(data[project.id_field]), 36) ).to_pylist()[0][project.id_field] raise ValueError( f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters." ) return data - def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]: - ''' + def _get_organization( + self, organization_slug=None, organization_id=None + ) -> Tuple[str, str]: + """ Gets an organization by either its name or id. Args: @@ -369,22 +403,26 @@ def _get_organization(self, organization_slug=None, organization_id=None) -> Tup Returns: The organization_slug and organization_id if one was found. 
- ''' + """ if organization_slug is None: - if organization_id is None: # default to current users organization (the one with their name) + if ( + organization_id is None + ): # default to current users organization (the one with their name) organization = self._get_current_users_main_organization() - organization_slug = organization['slug'] - organization_id = organization['organization_id'] + organization_slug = organization["slug"] + organization_id = organization["organization_id"] else: - raise NotImplementedError("Getting organization by a specific ID is not yet implemented.") + raise NotImplementedError( + "Getting organization by a specific ID is not yet implemented." + ) else: try: organization_id = self._get_organization_by_slug(slug=organization_slug) except Exception: user = self._get_current_user() - users_organizations = [org['slug'] for org in user['organizations']] + users_organizations = [org["slug"] for org in user["organizations"]] raise Exception( f"No such organization exists: {organization_slug}. You have access to the following organizations: {users_organizations}" ) @@ -400,26 +438,48 @@ class AtlasIndex: the points in the index that you can browse online. """ - def __init__(self, atlas_index_id, name, indexed_field, projections): - '''Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.''' + def __init__( + self, atlas_index_id, name, indexed_field, projections, dataset: AtlasDataset + ): + """Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.""" self.id = atlas_index_id self.name = name self.indexed_field = indexed_field self.projections = projections + self.dataset = dataset def _repr_html_(self): - return '
<br>'.join([d._repr_html_() for d in self.projections])
+        return "<br>
".join([d._repr_html_() for d in self.projections]) + + def delete(self): + """ + Deletes an atlas index with all associated metadata. + """ + response = requests.post( + self.dataset.atlas_api_path + f"/v1/project/index/remove", + headers=self.dataset.header, + json={"index_id": self.id, "project_id": self.dataset.id}, + ) + if not response.status_code == 200: + raise Exception(f"Failed to delete index: {response.text}") class AtlasProjection: - ''' + """ Interact and access state of an Atlas Map including text/vector search. This class should not be instantiated directly. Instead instantiate an AtlasDataset and use the dataset.maps attribute to retrieve an AtlasProjection. - ''' + """ - def __init__(self, dataset: "AtlasDataset", atlas_index_id: str, projection_id: str, name): + def __init__( + self, + dataset: "AtlasDataset", + atlas_index_id: str, + projection_id: str, + name, + index: AtlasIndex, + ): """ Creates an AtlasProjection. """ @@ -435,19 +495,21 @@ def __init__(self, dataset: "AtlasDataset", atlas_index_id: str, projection_id: self._tile_data = None self._data = None self._schema = None + self.index = index @property def map_link(self): - ''' + """ Retrieves a map link. - ''' + """ return f"{self.dataset.web_path}/data/{self.dataset.meta['organization_slug']}/{self.dataset.meta['slug']}/map" # return f"{self.project.web_path}/data/{self.project.meta['organization_slug']}/{self.project.meta['slug']}/map" @property def _status(self): response = requests.get( - self.dataset.atlas_api_path + f"/v1/project/index/job/progress/{self.atlas_index_id}", + self.dataset.atlas_api_path + + f"/v1/project/index/job/progress/{self.atlas_index_id}", headers=self.dataset.header, ) if response.status_code != 200: @@ -512,8 +574,8 @@ def _embed_html(self): def _repr_html_(self): # Don't make an iframe if the dataset is locked. - state = self._status['index_build_stage'] - if state != 'Completed': + state = self._status["index_build_stage"] + if state != "Completed": return f"""Atlas Projection {self.name}. Status {state}. view online""" return f"""

             <h3>Project: {self.dataset.slug}</h3>

@@ -524,7 +586,9 @@ def _repr_html_(self): def duplicates(self): """Duplicate detection state""" if self.dataset.is_locked: - raise Exception('Dataset is locked! Please wait until the dataset is unlocked to access duplicates.') + raise Exception( + "Dataset is locked! Please wait until the dataset is unlocked to access duplicates." + ) if self._duplicates is None: self._duplicates = AtlasMapDuplicates(self) return self._duplicates @@ -534,7 +598,7 @@ def topics(self): """Topic state""" if self.dataset.is_locked: raise Exception( - 'Dataset is locked for state access! Please wait until the dataset is unlocked to access topics.' + "Dataset is locked for state access! Please wait until the dataset is unlocked to access topics." ) if self._topics is None: self._topics = AtlasMapTopics(self) @@ -545,7 +609,7 @@ def embeddings(self): """Embedding state""" if self.dataset.is_locked: raise Exception( - 'Dataset is locked for state access! Please wait until the dataset is unlocked to access embeddings.' + "Dataset is locked for state access! Please wait until the dataset is unlocked to access embeddings." ) if self._embeddings is None: self._embeddings = AtlasMapEmbeddings(self) @@ -556,7 +620,7 @@ def tags(self): """Tag state""" if self.dataset.is_locked: raise Exception( - 'Dataset is locked for state access! Please wait until the dataset is unlocked to access tags.' + "Dataset is locked for state access! Please wait until the dataset is unlocked to access tags." ) if self._tags is None: self._tags = AtlasMapTags(self) @@ -567,7 +631,7 @@ def data(self): """Metadata state""" if self.dataset.is_locked: raise Exception( - 'Dataset is locked for state access! Please wait until the dataset is unlocked to access data.' + "Dataset is locked for state access! Please wait until the dataset is unlocked to access data." ) if self._data is None: self._data = AtlasMapData(self) @@ -578,11 +642,12 @@ def schema(self): """Projection arrow schema""" if self.dataset.is_locked: raise Exception( - 'Dataset is locked for state access! Please wait until the dataset is unlocked to access data.' + "Dataset is locked for state access! Please wait until the dataset is unlocked to access data." 
) if self._schema is None: response = requests.get( - self.dataset.atlas_api_path + f"/v1/project/projection/{self.projection_id}/schema", + self.dataset.atlas_api_path + + f"/v1/project/projection/{self.projection_id}/schema", headers=self.dataset.header, ) if response.status_code != 200: @@ -596,7 +661,7 @@ def _registered_sidecars(self) -> List[Tuple[str, str]]: "Returns [(field_name, sidecar_name), ...]" sidecars = [] for field in self.schema: - sidecar_name = json.loads(field.metadata.get(b'sidecar_name', b'""')) + sidecar_name = json.loads(field.metadata.get(b"sidecar_name", b'""')) if sidecar_name: sidecars.append((field.name, sidecar_name)) return sidecars @@ -615,27 +680,34 @@ def _fetch_tiles(self, overwrite: bool = True): return self._tile_data self._download_large_feather(overwrite=overwrite) tbs = [] - root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True) + root = feather.read_table( + self.tile_destination / "0/0/0.feather", memory_map=True + ) try: - sidecars = set([v for k, v in json.loads(root.schema.metadata[b'sidecars']).items()]) + sidecars = set( + [v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()] + ) except KeyError: sidecars = set([]) - sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) + sidecars |= set( + sidecar_name for (_, sidecar_name) in self._registered_sidecars() + ) for path in self._tiles_in_order(): - if isinstance(path, Path): - tb = pa.feather.read_table(path, memory_map=True) # type: ignore - for sidecar_file in sidecars: - carfile = pa.feather.read_table( # type: ignore - path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True - ) - for col in carfile.column_names: - tb = tb.append_column(col, carfile[col]) - tbs.append(tb) + tb = pa.feather.read_table(path, memory_map=True) + for sidecar_file in sidecars: + carfile = pa.feather.read_table( + path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True + ) + for col in carfile.column_names: + tb = tb.append_column(col, carfile[col]) + tbs.append(tb) self._tile_data = pa.concat_tables(tbs) return self._tile_data - def _tiles_in_order(self, coords_only=False) -> Iterator[Union[Tuple[int, int, int], Path]]: + def _tiles_in_order( + self, coords_only=False + ) -> Iterator[Union[Tuple[int, int, int], Path]]: """ Returns: A list of all tiles in the projection in a fixed order so that all @@ -656,7 +728,9 @@ def children(z, x, y): # Pop off the front, extend the back (breadth first traversal) while len(paths) > 0: z, x, y = paths.pop(0) - path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather") + path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix( + ".feather" + ) if path.exists(): if coords_only: yield (z, x, y) @@ -668,22 +742,26 @@ def children(z, x, y): def tile_destination(self): return Path("~/.nomic/cache", self.id).expanduser() - def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True): - ''' + def _download_large_feather( + self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True + ): + """ Downloads the feather tree. Args: overwrite: if True then overwrite existing feather files. Returns: A list containing all quadtiles downloads. - ''' + """ # TODO: change overwrite default to False once updating projection is removed. 
- quads = [f'0/0/0'] + quads = [f"0/0/0"] self.tile_destination.mkdir(parents=True, exist_ok=True) - root = f'{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/' + root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/" all_quads = [] sidecars = None - registered_sidecars = set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) + registered_sidecars = set( + sidecar_name for (_, sidecar_name) in self._registered_sidecars() + ) while len(quads) > 0: rawquad = quads.pop(0) quad = rawquad + ".feather" @@ -708,22 +786,24 @@ def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overw except pa.ArrowInvalid: path.unlink(missing_ok=True) - if not download_success or schema is None: + if not download_success: raise Exception(f"Failed to download tiles. Aborting...") - if sidecars is None and b'sidecars' in schema.metadata: + if sidecars is None and b"sidecars" in schema.metadata: # Grab just the filenames - sidecars = set([v for k, v in json.loads(schema.metadata.get(b'sidecars')).items()]) + sidecars = set( + [v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()] + ) elif sidecars is None: sidecars = set() if not "." in rawquad: for sidecar in sidecars | registered_sidecars: # The sidecar loses the feather suffix because it's supposed to be raw. - quads.append(quad.replace(".feather", f'.{sidecar}')) - if not schema.metadata or b'children' not in schema.metadata: + quads.append(quad.replace(".feather", f".{sidecar}")) + if not schema.metadata or b"children" not in schema.metadata: # Sidecars don't have children. continue - kids = schema.metadata.get(b'children') + kids = schema.metadata.get(b"children") children = json.loads(kids) quads.extend(children) return all_quads @@ -733,7 +813,7 @@ def datum_id_field(self): return self.dataset.meta["unique_id_field"] def _get_atoms(self, ids: List[str]) -> List[Dict]: - ''' + """ Retrieves atoms by id Args: @@ -742,7 +822,7 @@ def _get_atoms(self, ids: List[str]) -> List[Dict]: Returns: A dictionary containing the resulting atoms, keyed by atom id. 
- ''' + """ if not isinstance(ids, list): raise ValueError("You must specify a list of ids when getting data.") @@ -750,20 +830,26 @@ def _get_atoms(self, ids: List[str]) -> List[Dict]: response = requests.post( self.dataset.atlas_api_path + "/v1/project/atoms/get", headers=self.dataset.header, - json={'project_id': self.dataset.id, 'index_id': self.atlas_index_id, 'atom_ids': ids}, + json={ + "project_id": self.dataset.id, + "index_id": self.atlas_index_id, + "atom_ids": ids, + }, ) if response.status_code == 200: - return response.json()['atoms'] + return response.json()["atoms"] else: raise Exception(response.text) class AtlasDataStream(AtlasClass): - def __init__(self, name: Optional[str] = 'contrastors'): + def __init__(self, name: Optional[str] = "contrastors"): super().__init__() - if name != 'contrastors': - raise NotImplementedError("Only contrastors datastream is currently supported") + if name != "contrastors": + raise NotImplementedError( + "Only contrastors datastream is currently supported" + ) self.name = name # TODO: add support for other datastreams @@ -783,7 +869,7 @@ class AtlasDataset(AtlasClass): def __init__( self, identifier: Optional[str] = None, - description: Optional[str] = 'A description for your map.', + description: Optional[str] = "A description for your map.", unique_id_field: Optional[str] = None, is_public: bool = True, dataset_id=None, @@ -802,7 +888,9 @@ def __init__( * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view. * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists. """ - assert identifier is not None or dataset_id is not None, "You must pass a dataset identifier" + assert ( + identifier is not None or dataset_id is not None + ), "You must pass a dataset identifier" super().__init__() @@ -815,18 +903,20 @@ def __init__( self.meta = self._get_project_by_id(dataset_id) return - if not self.is_valid_dataset_identifier(identifier=str(identifier)): - default_org_slug = self._get_current_users_main_organization()['slug'] - identifier = default_org_slug + '/' + identifier + if not self.is_valid_dataset_identifier(identifier=identifier): + default_org_slug = self._get_current_users_main_organization()["slug"] + identifier = default_org_slug + "/" + identifier dataset = self._get_dataset_by_slug_identifier(identifier=str(identifier)) if dataset: # dataset already exists - logger.info(f"Loading existing dataset `{identifier}`.") - dataset_id = dataset['id'] + logger.info(f"Loading existing dataset `{identifier}``.") + dataset_id = dataset["id"] if dataset_id is None: # if there is no existing project, make a new one. - if unique_id_field is None: # if not all parameters are specified, we weren't trying to make a project + if ( + unique_id_field is None + ): # if not all parameters are specified, we weren't trying to make a project raise ValueError(f"Dataset `{identifier}` does not exist.") # if modality is None: @@ -846,13 +936,15 @@ def __init__( self._schema = None def delete(self): - ''' + """ Deletes an atlas dataset with all associated metadata. 
- ''' + """ organization = self._get_current_users_main_organization() - organization_slug = organization['slug'] + organization_slug = organization["slug"] - logger.info(f"Deleting dataset `{self.slug}` from organization `{organization_slug}`") + logger.info( + f"Deleting dataset `{self.slug}` from organization `{organization_slug}`" + ) self._delete_project_by_id(project_id=self.id) @@ -865,7 +957,7 @@ def _create_project( unique_id_field: str, is_public: bool = True, ): - ''' + """ Creates an Atlas Dataset. Atlas Datasets store data (text, embeddings, etc) that you can organize by building indices. If the organization already contains a dataset with this name, it will be returned instead. @@ -879,10 +971,10 @@ def _create_project( **Returns:** project_id on success. - ''' + """ organization_id = self._get_organization_by_slug(slug=identifier) - project_slug = identifier.split('/')[1] + project_slug = identifier.split("/")[1] # supported_modalities = ['text', 'embedding'] # if modality not in supported_modalities: @@ -899,12 +991,12 @@ def _create_project( self.atlas_api_path + "/v1/project/create", headers=self.header, json={ - 'organization_id': organization_id, - 'project_name': project_slug, - 'description': description, - 'unique_id_field': unique_id_field, + "organization_id": organization_id, + "project_name": project_slug, + "description": description, + "unique_id_field": unique_id_field, # 'modality': modality, - 'is_public': is_public, + "is_public": is_public, }, ) @@ -913,12 +1005,12 @@ def _create_project( logger.info(f"Creating dataset `{response.json()['slug']}`") - return response.json()['project_id'] + return response.json()["project_id"] def _latest_dataset_state(self): - ''' + """ Refreshes the project's state. Try to call this sparingly but use it when you need it. 
- ''' + """ self.meta = self._get_project_by_id(self.id) return self @@ -927,20 +1019,26 @@ def _latest_dataset_state(self): def indices(self) -> List[AtlasIndex]: self._latest_dataset_state() output = [] - for index in self.meta['atlas_indices']: + for index_info in self.meta["atlas_indices"]: projections = [] - for projection in index['projections']: - projection = AtlasProjection( - dataset=self, projection_id=projection['id'], atlas_index_id=index['id'], name=index['index_name'] - ) - projections.append(projection) index = AtlasIndex( - atlas_index_id=index['id'], - name=index['index_name'], - indexed_field=index['indexed_field'], + atlas_index_id=index_info["id"], + name=index_info["index_name"], + indexed_field=index_info["indexed_field"], projections=projections, + dataset=self, ) - output.append(index) + for projection in index_info["projections"]: + projection = AtlasProjection( + dataset=self, + projection_id=projection["id"], + atlas_index_id=index_info["id"], + name=index_info["index_name"], + index=index, + ) + projections.append(projection) + + output.append(index_info) return output @@ -958,76 +1056,80 @@ def maps(self) -> List[AtlasProjection]: @property def id(self) -> str: - '''The UUID of the dataset.''' - return self.meta['id'] + """The UUID of the dataset.""" + return self.meta["id"] @property def id_field(self) -> str: - return self.meta['unique_id_field'] + return self.meta["unique_id_field"] @property def created_timestamp(self) -> datetime: - return datetime.strptime(self.meta['created_timestamp'], '%Y-%m-%dT%H:%M:%S.%f%z') + return datetime.strptime( + self.meta["created_timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z" + ) @property def total_datums(self) -> int: - '''The total number of data points in the dataset.''' - return self.meta['total_datums_in_project'] + """The total number of data points in the dataset.""" + return self.meta["total_datums_in_project"] @property def modality(self) -> str: - return self.meta['modality'] + return self.meta["modality"] @property def name(self) -> str: - '''The customizable name of the dataset.''' - return self.meta['project_name'] + """The customizable name of the dataset.""" + return self.meta["project_name"] @property def slug(self) -> str: - '''The URL-safe identifier for this dataset.''' - return self.meta['slug'] + """The URL-safe identifier for this dataset.""" + return self.meta["slug"] @property def identifier(self) -> str: - '''The Atlas globally unique, URL-safe identifier for this dataset''' - return self.meta['organization_slug'] + '/' + self.meta['slug'] + """The Atlas globally unique, URL-safe identifier for this dataset""" + return self.meta["organization_slug"] + "/" + self.meta["slug"] @property def description(self): - return self.meta['description'] + return self.meta["description"] @property def dataset_fields(self): - return self.meta['project_fields'] + return self.meta["project_fields"] @property def is_locked(self) -> bool: self._latest_dataset_state() - return self.meta['insert_update_delete_lock'] + return self.meta["insert_update_delete_lock"] @property def schema(self) -> Optional[pa.Schema]: if self._schema is not None: return self._schema - if 'schema' in self.meta and self.meta['schema'] is not None: - self._schema: pa.Schema = ipc.read_schema(io.BytesIO(base64.b64decode(self.meta['schema']))) + if "schema" in self.meta and self.meta["schema"] is not None: + self._schema: pa.Schema = ipc.read_schema( + io.BytesIO(base64.b64decode(self.meta["schema"])) + ) return self._schema return None @property def 
is_accepting_data(self) -> bool: - ''' + """ Checks if the dataset can accept data. Datasets cannot accept data when they are being indexed. Returns: True if dataset is unlocked for data additions, false otherwise. - ''' + """ return not self.is_locked @contextmanager def wait_for_dataset_lock(self): - '''Blocks thread execution until dataset is in a state where it can ingest data.''' + """Blocks thread execution until dataset is in a state where it can ingest data.""" has_logged = False while True: if self.is_accepting_data: @@ -1039,9 +1141,9 @@ def wait_for_dataset_lock(self): time.sleep(5) def get_map( - self, name: Optional[str] = None, atlas_index_id: Optional[str] = None, projection_id: Optional[str] = None + self, name: str = None, atlas_index_id: str = None, projection_id: str = None ) -> AtlasProjection: - ''' + """ Retrieves a map. Args: @@ -1051,7 +1153,7 @@ def get_map( Returns: The map or a ValueError. - ''' + """ indices = self.indices @@ -1059,16 +1161,22 @@ def get_map( for index in indices: if index.id == atlas_index_id: if len(index.projections) == 0: - raise ValueError(f"No map found under index with atlas_index_id='{atlas_index_id}'") + raise ValueError( + f"No map found under index with atlas_index_id='{atlas_index_id}'" + ) return index.projections[0] - raise ValueError(f"Could not find a map with atlas_index_id='{atlas_index_id}'") + raise ValueError( + f"Could not find a map with atlas_index_id='{atlas_index_id}'" + ) if projection_id is not None: for index in indices: for projection in index.projections: if projection.id == projection_id: return projection - raise ValueError(f"Could not find a map with projection_id='{atlas_index_id}'") + raise ValueError( + f"Could not find a map with projection_id='{atlas_index_id}'" + ) if len(indices) == 0: raise ValueError("You have no maps built in your project") @@ -1094,9 +1202,9 @@ def create_index( topic_model: Union[bool, Dict, NomicTopicOptions] = True, duplicate_detection: Union[bool, Dict, NomicDuplicatesOptions] = True, embedding_model: Optional[Union[str, Dict, NomicEmbedOptions]] = None, - reuse_embeddings_from_index: Optional[str] = None, - ) -> Union[AtlasProjection, None]: - ''' + reuse_embeddings_from_index: str = None, + ) -> AtlasProjection: + """ Creates an index in the specified dataset. Args: @@ -1112,7 +1220,7 @@ def create_index( Returns: The projection this index has built. - ''' + """ self._latest_dataset_state() @@ -1151,53 +1259,56 @@ def create_index( colorable_fields = [] for field in self.dataset_fields: - if field not in [self.id_field, indexed_field] and not field.startswith("_"): + if field not in [self.id_field, indexed_field] and not field.startswith( + "_" + ): colorable_fields.append(field) - build_template = {} - if self.modality == 'embedding': + if self.modality == "embedding": if topic_model.community_description_target_field is None: logger.warning( "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics." 
) build_template = { - 'project_id': self.id, - 'index_name': name, - 'indexed_field': None, - 'atomizer_strategies': None, - 'model': None, - 'colorable_fields': colorable_fields, - 'model_hyperparameters': None, - 'nearest_neighbor_index': 'HNSWIndex', - 'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}), - 'projection': 'NomicProject', - 'projection_hyperparameters': json.dumps( + "project_id": self.id, + "index_name": name, + "indexed_field": None, + "atomizer_strategies": None, + "model": None, + "colorable_fields": colorable_fields, + "model_hyperparameters": None, + "nearest_neighbor_index": "HNSWIndex", + "nearest_neighbor_index_hyperparameters": json.dumps( + {"space": "l2", "ef_construction": 100, "M": 16} + ), + "projection": "NomicProject", + "projection_hyperparameters": json.dumps( { - 'n_neighbors': projection.n_neighbors, - 'n_epochs': projection.n_epochs, - 'spread': projection.spread, - 'local_neighborhood_size': projection.local_neighborhood_size, - 'rho': projection.rho, - 'model': projection.model, + "n_neighbors": projection.n_neighbors, + "n_epochs": projection.n_epochs, + "spread": projection.spread, + "local_neighborhood_size": projection.local_neighborhood_size, + "rho": projection.rho, + "model": projection.model, } ), - 'topic_model_hyperparameters': json.dumps( + "topic_model_hyperparameters": json.dumps( { - 'build_topic_model': topic_model.build_topic_model, - 'community_description_target_field': topic_model.community_description_target_field, - 'cluster_method': topic_model.cluster_method, - 'enforce_topic_hierarchy': topic_model.enforce_topic_hierarchy, + "build_topic_model": topic_model.build_topic_model, + "community_description_target_field": topic_model.community_description_target_field, + "cluster_method": topic_model.cluster_method, + "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy, } ), - 'duplicate_detection_hyperparameters': json.dumps( + "duplicate_detection_hyperparameters": json.dumps( { - 'tag_duplicates': duplicate_detection.tag_duplicates, - 'duplicate_cutoff': duplicate_detection.duplicate_cutoff, + "tag_duplicates": duplicate_detection.tag_duplicates, + "duplicate_cutoff": duplicate_detection.duplicate_cutoff, } ), } - elif self.modality == 'text': + elif self.modality == "text": # find the index id of the index with name reuse_embeddings_from_index reuse_embedding_from_index_id = None indices = self.indices @@ -1212,52 +1323,58 @@ def create_index( ) if indexed_field is None: - raise Exception("You did not specify a field to index. Specify an 'indexed_field'.") + raise Exception( + "You did not specify a field to index. Specify an 'indexed_field'." + ) if indexed_field not in self.dataset_fields: - raise Exception(f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}") + raise Exception( + f"Indexing on {indexed_field} not allowed. 
Valid options are: {self.dataset_fields}" + ) build_template = { - 'project_id': self.id, - 'index_name': name, - 'indexed_field': indexed_field, - 'atomizer_strategies': ['document', 'charchunk'], - 'model': embedding_model.model, - 'colorable_fields': colorable_fields, - 'reuse_atoms_and_embeddings_from': reuse_embedding_from_index_id, - 'model_hyperparameters': json.dumps( + "project_id": self.id, + "index_name": name, + "indexed_field": indexed_field, + "atomizer_strategies": ["document", "charchunk"], + "model": embedding_model.model, + "colorable_fields": colorable_fields, + "reuse_atoms_and_embeddings_from": reuse_embedding_from_index_id, + "model_hyperparameters": json.dumps( { - 'dataset_buffer_size': 1000, - 'batch_size': 20, - 'polymerize_by': 'charchunk', - 'norm': 'both', + "dataset_buffer_size": 1000, + "batch_size": 20, + "polymerize_by": "charchunk", + "norm": "both", } ), - 'nearest_neighbor_index': 'HNSWIndex', - 'nearest_neighbor_index_hyperparameters': json.dumps({'space': 'l2', 'ef_construction': 100, 'M': 16}), - 'projection': 'NomicProject', - 'projection_hyperparameters': json.dumps( + "nearest_neighbor_index": "HNSWIndex", + "nearest_neighbor_index_hyperparameters": json.dumps( + {"space": "l2", "ef_construction": 100, "M": 16} + ), + "projection": "NomicProject", + "projection_hyperparameters": json.dumps( { - 'n_neighbors': projection.n_neighbors, - 'n_epochs': projection.n_epochs, - 'spread': projection.spread, - 'local_neighborhood_size': projection.local_neighborhood_size, - 'rho': projection.rho, - 'model': projection.model, + "n_neighbors": projection.n_neighbors, + "n_epochs": projection.n_epochs, + "spread": projection.spread, + "local_neighborhood_size": projection.local_neighborhood_size, + "rho": projection.rho, + "model": projection.model, } ), - 'topic_model_hyperparameters': json.dumps( + "topic_model_hyperparameters": json.dumps( { - 'build_topic_model': topic_model.build_topic_model, - 'community_description_target_field': indexed_field, - 'cluster_method': topic_model.build_topic_model, - 'enforce_topic_hierarchy': topic_model.enforce_topic_hierarchy, + "build_topic_model": topic_model.build_topic_model, + "community_description_target_field": indexed_field, + "cluster_method": topic_model.build_topic_model, + "enforce_topic_hierarchy": topic_model.enforce_topic_hierarchy, } ), - 'duplicate_detection_hyperparameters': json.dumps( + "duplicate_detection_hyperparameters": json.dumps( { - 'tag_duplicates': duplicate_detection.tag_duplicates, - 'duplicate_cutoff': duplicate_detection.duplicate_cutoff, + "tag_duplicates": duplicate_detection.tag_duplicates, + "duplicate_cutoff": duplicate_detection.duplicate_cutoff, } ), } @@ -1268,18 +1385,20 @@ def create_index( json=build_template, ) if response.status_code != 200: - logger.info('Create dataset failed with code: {}'.format(response.status_code)) - logger.info('Additional info: {}'.format(response.text)) - raise Exception(response.json()['detail']) + logger.info( + "Create dataset failed with code: {}".format(response.status_code) + ) + logger.info("Additional info: {}".format(response.text)) + raise Exception(response.json()["detail"]) - job_id = response.json()['job_id'] + job_id = response.json()["job_id"] job = requests.get( self.atlas_api_path + f"/v1/project/index/job/{job_id}", headers=self.header, ).json() - index_id = job['index_id'] + index_id = job["index_id"] try: atlas_projection = self.get_map(atlas_index_id=index_id) @@ -1293,11 +1412,10 @@ def create_index( if atlas_projection is 
None: logger.warning("Could not find a map being built for this dataset.") - else: - logger.info( - f"Created map `{atlas_projection.name}` in dataset `{self.identifier}`: {atlas_projection.map_link}" - ) - return atlas_projection + logger.info( + f"Created map `{projection.name}` in dataset `{self.identifier}`: {projection.map_link}" + ) + return projection def __repr__(self): m = self.meta @@ -1318,8 +1436,8 @@ def _repr_html_(self): html += "
Projections\n" html += "" @@ -1330,10 +1448,16 @@ def _repr_html_(self): return html def __str__(self): - return "\n".join([str(projection) for index in self.indices for projection in index.projections]) + return "\n".join( + [ + str(projection) + for index in self.indices + for projection in index.projections + ] + ) def get_data(self, ids: List[str]) -> List[Dict]: - ''' + """ Retrieve the contents of the data given ids. Args: @@ -1342,25 +1466,27 @@ def get_data(self, ids: List[str]) -> List[Dict]: Returns: A list of dictionaries corresponding to the data. - ''' + """ if not isinstance(ids, list): raise ValueError("You must specify a list of ids when getting data.") if isinstance(ids[0], list): - raise ValueError("You must specify a list of ids when getting data, not a nested list.") + raise ValueError( + "You must specify a list of ids when getting data, not a nested list." + ) response = requests.post( self.atlas_api_path + "/v1/project/data/get", headers=self.header, - json={'project_id': self.id, 'datum_ids': ids}, + json={"project_id": self.id, "datum_ids": ids}, ) if response.status_code == 200: - return [item for item in response.json()['datums']] + return [item for item in response.json()["datums"]] else: raise Exception(response.text) def delete_data(self, ids: List[str]) -> bool: - ''' + """ Deletes the specified datapoints from the dataset. Args: @@ -1369,14 +1495,14 @@ def delete_data(self, ids: List[str]) -> bool: Returns: True if data deleted successfully. - ''' + """ if not isinstance(ids, list): raise ValueError("You must specify a list of ids when deleting datums.") response = requests.post( self.atlas_api_path + "/v1/project/data/delete", headers=self.header, - json={'project_id': self.id, 'datum_ids': ids}, + json={"project_id": self.id, "datum_ids": ids}, ) if response.status_code == 200: @@ -1384,7 +1510,12 @@ def delete_data(self, ids: List[str]) -> bool: else: raise Exception(response.text) - def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Optional[np.ndarray] = None, pbar=None): + def add_data( + self, + data=Union[DataFrame, List[Dict], pa.Table], + embeddings: np.array = None, + pbar=None, + ): """ Adds data of varying modality to an Atlas dataset. Args: @@ -1392,10 +1523,9 @@ def add_data(self, data=Union[DataFrame, List[Dict], pa.Table], embeddings: Opti embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints. pbar: (Optional). A tqdm progress bar to update. """ - if embeddings is not None: - self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar) - elif isinstance(data, pa.Table) and "_embeddings" in data.column_names: # type: ignore - embeddings = np.array(data.column('_embeddings').to_pylist()) # type: ignore + if embeddings is not None or ( + isinstance(data, pa.Table) and "_embeddings" in data.column_names + ): self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar) else: self._add_text(data=data, pbar=pbar) @@ -1411,10 +1541,17 @@ def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None): elif isinstance(data, list): data = pa.Table.from_pylist(data) elif not isinstance(data, pa.Table): - raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.") + raise ValueError( + "Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table." 
+ ) self._add_data(data, pbar=pbar) - def _add_embeddings(self, data: Union[DataFrame, List[Dict], pa.Table], embeddings: np.ndarray, pbar=None): + def _add_embeddings( + self, + data: Union[DataFrame, List[Dict], pa.Table, None], + embeddings: np.array, + pbar=None, + ): """ Add data, with associated embeddings, to the dataset. @@ -1430,7 +1567,9 @@ def _add_embeddings(self, data: Union[DataFrame, List[Dict], pa.Table], embeddin """ assert type(embeddings) == np.ndarray, "Embeddings must be a NumPy array." assert len(embeddings.shape) == 2, "Embeddings must be a 2D NumPy array." - assert len(data) == embeddings.shape[0], "Data and embeddings must have the same number of rows." + assert ( + len(data) == embeddings.shape[0] + ), "Data and embeddings must have the same number of rows." assert len(data) > 0, "Data must have at least one row." tb: pa.Table @@ -1456,7 +1595,9 @@ def _add_embeddings(self, data: Union[DataFrame, List[Dict], pa.Table], embeddin assert not np.isnan(embeddings).any(), "Embeddings must not contain NaN values." assert not np.isinf(embeddings).any(), "Embeddings must not contain Inf values." - pyarrow_embeddings = pa.FixedSizeListArray.from_arrays(embeddings.reshape((-1)), embeddings.shape[1]) + pyarrow_embeddings = pa.FixedSizeListArray.from_arrays( + embeddings.reshape((-1)), embeddings.shape[1] + ) data_with_embeddings = tb.append_column("_embeddings", pyarrow_embeddings) @@ -1467,7 +1608,7 @@ def _add_data( data: pa.Table, pbar=None, ): - ''' + """ Low level interface to upload an Arrow Table. Users should generally call 'add_text' or 'add_embeddings.' Args: @@ -1475,7 +1616,7 @@ def _add_data( pbar: A tqdm progress bar to update. Returns: None - ''' + """ # Exactly 10 upload workers at a time. @@ -1506,8 +1647,10 @@ def _add_data( def send_request(i): data_shard = data.slice(i, shard_size) with io.BytesIO() as buffer: - data_shard = data_shard.replace_schema_metadata({'project_id': self.id}) - feather.write_feather(data_shard, buffer, compression='zstd', compression_level=6) + data_shard = data_shard.replace_schema_metadata({"project_id": self.id}) + feather.write_feather( + data_shard, buffer, compression="zstd", compression_level=6 + ) buffer.seek(0) response = requests.post( @@ -1526,29 +1669,39 @@ def send_request(i): succeeded = 0 errors_504 = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = {executor.submit(send_request, i): i for i in range(0, len(data), shard_size)} + futures = { + executor.submit(send_request, i): i + for i in range(0, len(data), shard_size) + } while futures: # check for status of the futures which are currently working - done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED) + done, not_done = concurrent.futures.wait( + futures, return_when=concurrent.futures.FIRST_COMPLETED + ) # process any completed futures for future in done: response = future.result() if response.status_code != 200: try: logger.error(f"Shard upload failed: {response.text}") - if 'more datums exceeds your organization limit' in response.json(): + if ( + "more datums exceeds your organization limit" + in response.json() + ): return False - if 'Project transaction lock is held' in response.json(): + if "Project transaction lock is held" in response.json(): raise Exception( "Project is currently indexing and cannot ingest new datums. Try again later." 
) - if 'Insert failed due to ID conflict' in response.json(): + if "Insert failed due to ID conflict" in response.json(): continue except (requests.JSONDecodeError, json.decoder.JSONDecodeError): if response.status_code == 413: # Possibly split in two and retry? - logger.error("Shard upload failed: you are sending meta-data that is too large.") + logger.error( + "Shard upload failed: you are sending meta-data that is too large." + ) pbar.update(1) response.close() failed += shard_size @@ -1558,16 +1711,25 @@ def send_request(i): logger.debug( f"{self.identifier}: Connection failed for records {start_point}-{start_point + shard_size}, retrying." ) - failure_fraction = errors_504 / (failed + succeeded + errors_504) - if failure_fraction > 0.5 and errors_504 > shard_size * 3: + failure_fraction = errors_504 / ( + failed + succeeded + errors_504 + ) + if ( + failure_fraction > 0.5 + and errors_504 > shard_size * 3 + ): raise RuntimeError( f"{self.identifier}: Atlas is under high load and cannot ingest datums at this time. Please try again later." ) - new_submission = executor.submit(send_request, start_point) + new_submission = executor.submit( + send_request, start_point + ) futures[new_submission] = start_point response.close() else: - logger.error(f"{self.identifier}: Shard upload failed: {response}") + logger.error( + f"{self.identifier}: Shard upload failed: {response}" + ) failed += shard_size pbar.update(1) response.close() @@ -1592,8 +1754,13 @@ def send_request(i): else: logger.info("Upload succeeded.") - def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, num_workers: int = 10): - ''' + def update_maps( + self, + data: List[Dict], + embeddings: Optional[np.array] = None, + num_workers: int = 10, + ): + """ Utility method to update a project's maps by adding the given data. Args: @@ -1602,29 +1769,29 @@ def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, shard_size: Data is uploaded in parallel by many threads. Adjust the number of datums to upload by each worker. num_workers: The number of workers to use when sending data. 
- ''' + """ # Validate data - if self.modality == 'embedding' and embeddings is None: - msg = 'Please specify embeddings for updating an embedding project' + if self.modality == "embedding" and embeddings is None: + msg = "Please specify embeddings for updating an embedding project" raise ValueError(msg) - if self.modality == 'text' and embeddings is not None: - msg = 'Please dont specify embeddings for updating a text project' + if self.modality == "text" and embeddings is not None: + msg = "Please dont specify embeddings for updating a text project" raise ValueError(msg) if embeddings is not None and len(data) != embeddings.shape[0]: - msg = ( - 'Expected data and embeddings to be the same length but found lengths {} and {} respectively.'.format() - ) + msg = "Expected data and embeddings to be the same length but found lengths {} and {} respectively.".format() raise ValueError(msg) - shard_size = 2000 # TODO someone removed shard size from params and didn't update + shard_size = ( + 2000 # TODO someone removed shard size from params and didn't update + ) # Add new data logger.info("Uploading data to Nomic's neural database Atlas.") with tqdm(total=len(data) // shard_size) as pbar: for i in range(0, len(data), MAX_MEMORY_CHUNK): - if self.modality == 'embedding' and embeddings is not None: + if self.modality == "embedding": self._add_embeddings( embeddings=embeddings[i : i + MAX_MEMORY_CHUNK, :], data=data[i : i + MAX_MEMORY_CHUNK], @@ -1642,18 +1809,18 @@ def update_maps(self, data: List[Dict], embeddings: Optional[np.ndarray] = None, return self.update_indices() def update_indices(self, rebuild_topic_models: bool = False): - ''' + """ Rebuilds all maps in a dataset with the latest state dataset data state. Maps will not be rebuilt to reflect the additions, deletions or updates you have made to your data until this method is called. Args: rebuild_topic_models: (Default False) - If true, will create new topic models when updating these indices. - ''' + """ response = requests.post( self.atlas_api_path + "/v1/project/update_indices", headers=self.header, - json={'project_id': self.id, 'rebuild_topic_models': rebuild_topic_models}, + json={"project_id": self.id, "rebuild_topic_models": rebuild_topic_models}, ) logger.info(f"Updating maps in dataset `{self.identifier}`") From 3ab6ba792f3c6106b24450367eb4d715434509bf Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Fri, 19 Apr 2024 18:03:15 -0400 Subject: [PATCH 2/5] add index deletion, linting catastrophe. --- nomic/aws/sagemaker.py | 19 +-- nomic/data_operations.py | 191 ++++++++++++++-------------- nomic/dataset.py | 267 ++++++++++----------------------------- nomic/embed.py | 54 +++----- 4 files changed, 183 insertions(+), 348 deletions(-) diff --git a/nomic/aws/sagemaker.py b/nomic/aws/sagemaker.py index 3c44d63d..6b4df253 100644 --- a/nomic/aws/sagemaker.py +++ b/nomic/aws/sagemaker.py @@ -129,8 +129,6 @@ def embed_texts( region_name: str, task_type: str = "search_document", batch_size: int = 32, - dimensionality: int = 768, - binary: bool = False, ): """ Embed a list of texts using a sagemaker model endpoint. @@ -141,8 +139,6 @@ def embed_texts( region_name: AWS region sagemaker endpoint is in. task_type: The task type to use when embedding. batch_size: Size of each batch. Default is 32. - dimensionality: Number of dimensions to return. Options are (64, 128, 256, 512, 768). - binary: Whether to return binary embeddings. 
Returns: Dictionary with "embeddings" (python 2d list of floats), "model" (sagemaker endpoint used to generate embeddings). @@ -153,25 +149,12 @@ def embed_texts( return None texts = preprocess_texts(texts, task_type) - assert dimensionality in ( - 64, - 128, - 256, - 512, - 768, - ), f"Invalid number of dimensions: {dimensionality}" client = boto3.client("sagemaker-runtime", region_name=region_name) embeddings = [] for i in tqdm(range(0, len(texts), batch_size)): - batch = json.dumps( - { - "texts": texts[i : i + batch_size], - "binary": binary, - "dimensionality": dimensionality, - } - ) + batch = json.dumps({"texts": texts[i : i + batch_size]}) response = client.invoke_endpoint(EndpointName=sagemaker_endpoint, Body=batch, ContentType="application/json") embeddings.extend(parse_sagemaker_response(response)) diff --git a/nomic/data_operations.py b/nomic/data_operations.py index 533f0c7b..75ac4ffa 100644 --- a/nomic/data_operations.py +++ b/nomic/data_operations.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional, Tuple import numpy as np +import pandas import pandas as pd import pyarrow as pa import requests @@ -20,7 +21,6 @@ from pyarrow import feather, ipc from tqdm import tqdm - from .settings import EMBEDDING_PAGINATION_LIMIT from .utils import download_feather @@ -32,7 +32,7 @@ class AtlasMapDuplicates: your data. """ - def __init__(self, projection: "AtlasProjection"): # type: ignore + def __init__(self, projection: "AtlasProjection"): self.projection = projection self.id_field = self.projection.dataset.id_field try: @@ -46,7 +46,7 @@ def __init__(self, projection: "AtlasProjection"): # type: ignore self._tb: pa.Table = projection._fetch_tiles().select( [self.id_field, self.duplicate_field, self.cluster_field] ) - except pa.lib.ArrowInvalid as e: # type: ignore + except pa.lib.ArrowInvalid as e: raise ValueError("Duplicate detection has not yet been run on this map.") self.duplicate_field = self.duplicate_field.lstrip("_") self.cluster_field = self.cluster_field.lstrip("_") @@ -74,13 +74,13 @@ def deletion_candidates(self) -> List[str]: Returns: The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated. 
""" - dupes = self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], 'deletion candidate')) # type: ignore + dupes = self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')) return dupes.to_pylist() def __repr__(self) -> str: repr = f"===Atlas Duplicates for ({self.projection})\n" duplicate_count = len( - self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], 'deletion candidate')) # type: ignore + self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')) ) cluster_count = len(self.tb[self.cluster_field].value_counts()) repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n" @@ -92,7 +92,7 @@ class AtlasMapTopics: Atlas Topics State """ - def __init__(self, projection: "AtlasProjection"): # type: ignore + def __init__(self, projection: "AtlasProjection"): self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field @@ -107,11 +107,12 @@ def __init__(self, projection: "AtlasProjection"): # type: ignore # If using topic ids, fetch topic labels if 'int' in topic_fields[0]: new_topic_fields = [] - label_df = self.metadata[["topic_id", "depth", "topic_short_description"]] + metadata = self.metadata + label_df = metadata[["topic_id", "depth", "topic_short_description"]] for d in range(1, self.depth + 1): column = f"_topic_depth_{d}_int" topic_ids_to_label = self._tb[column].to_pandas().rename('topic_id') - topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge( + topic_ids_to_label = label_df[label_df["depth"] == d].merge( topic_ids_to_label, on='topic_id', how='right' ) new_column = f"_topic_depth_{d}" @@ -124,11 +125,11 @@ def __init__(self, projection: "AtlasProjection"): # type: ignore renamed_fields = [f'topic_depth_{i}' for i in range(1, self.depth + 1)] self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns([self.id_field] + renamed_fields) - except pa.lib.ArrowInvalid as e: # type: ignore + except pa.lib.ArrowInvalid as e: raise ValueError("Topic modeling has not yet been run on this map.") @property - def df(self) -> pd.DataFrame: + def df(self) -> pandas.DataFrame: """ A pandas DataFrame associating each datapoint on your map to their topics as each topic depth. """ @@ -144,7 +145,7 @@ def tb(self) -> pa.Table: return self._tb @property - def metadata(self) -> pd.DataFrame: + def metadata(self) -> pandas.DataFrame: """ Pandas DataFrame where each row gives metadata all map topics including: @@ -282,7 +283,7 @@ def get_topic_density(self, time_field: str, start: datetime, end: datetime): topic_densities[topic] += row[self.id_field + '_count'] return topic_densities - def vector_search_topics(self, queries: np.ndarray, k: int = 32, depth: int = 3) -> Dict: + def vector_search_topics(self, queries: np.array, k: int = 32, depth: int = 3) -> Dict: ''' Given an embedding, returns a normalized distribution over topics. @@ -377,7 +378,7 @@ class AtlasMapEmbeddings: """ - def __init__(self, projection: "AtlasProjection"): # type: ignore + def __init__(self, projection: "AtlasProjection"): self.projection = projection self.id_field = self.projection.dataset.id_field self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y']) @@ -417,7 +418,7 @@ def projected(self) -> pd.DataFrame: return self.df @property - def latent(self) -> np.ndarray: + def latent(self) -> np.array: """ High dimensional embeddings. 
@@ -433,21 +434,20 @@ def latent(self) -> np.ndarray: self._download_latent() all_embeddings = [] - for path in self.projection._tiles_in_order(coords_only=False): + for path in self.projection._tiles_in_order(): # double with-suffix to remove '.embeddings.feather' - if isinstance(path, Path): - files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather") - # Should there be more than 10, we need to sort by int values, not string values - sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1])) - if len(sortable) == 0: - raise FileNotFoundError( - "Could not find any embeddings for tile {}".format(path) - + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'." - ) - for file in sortable: - tb = feather.read_table(file, memory_map=True) - dims = tb['_embeddings'].type.list_size - all_embeddings.append(pa.compute.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims)) # type: ignore + files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather") + # Should there be more than 10, we need to sort by int values, not string values + sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1])) + if len(sortable) == 0: + raise FileNotFoundError( + "Could not find any embeddings for tile {}".format(path) + + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'." + ) + for file in sortable: + tb = feather.read_table(file, memory_map=True) + dims = tb['_embeddings'].type.list_size + all_embeddings.append(pc.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims)) return np.vstack(all_embeddings) def _download_latent(self): @@ -476,9 +476,7 @@ def _download_latent(self): last = tilename pbar.update(1) - def vector_search( - self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5 - ) -> Tuple[List, List]: + def vector_search(self, queries: np.array = None, ids: List[str] = None, k: int = 5) -> Dict[str, List]: ''' Performs semantic vector search over data points on your map. If ids is specified, receive back the most similar data ids in latent vector space to your input ids. @@ -512,6 +510,8 @@ def vector_search( raise Exception("`queries` must be an instance of np.array.") if queries.shape[0] > max_queries: raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.") + + if queries is not None: if queries.ndim != 2: raise ValueError( 'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).' @@ -520,6 +520,7 @@ def vector_search( bytesio = io.BytesIO() np.save(bytesio, queries) + if queries is not None: response = requests.post( self.projection.dataset.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding", headers=self.projection.dataset.header, @@ -584,7 +585,7 @@ class AtlasMapTags: the associated pandas DataFrame. 
""" - def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): # type: ignore + def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field @@ -606,27 +607,28 @@ def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: tbs = [] all_quads = list(self.projection._tiles_in_order(coords_only=True)) for quad in tqdm(all_quads): - if isinstance(quad, Tuple): - quad_str = os.path.join(*[str(q) for q in quad]) - datum_id_filename = quad_str + "." + "datum_id" + ".feather" - path = self.projection.tile_destination / Path(datum_id_filename) - tb = feather.read_table(path, memory_map=True) - for tag in tags: - tag_definition_id = tag["tag_definition_id"] - tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" - path = self.projection.tile_destination / Path(tag_filename) - tag_tb = feather.read_table(path, memory_map=True) - bitmask = None - if "all_set" in tag_tb.column_names: - bool_v = tag_tb["all_set"][0].as_py() == True - bitmask = pa.array([bool_v] * len(tb), type=pa.bool_()) + quad_str = os.path.join(*[str(q) for q in quad]) + datum_id_filename = quad_str + "." + "datum_id" + ".feather" + path = self.projection.tile_destination / Path(datum_id_filename) + tb = feather.read_table(path, memory_map=True) + for tag in tags: + tag_definition_id = tag["tag_definition_id"] + tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" + path = self.projection.tile_destination / Path(tag_filename) + tag_tb = feather.read_table(path, memory_map=True) + bitmask = None + if "all_set" in tag_tb.column_names: + if tag_tb["all_set"][0].as_py() == True: + bitmask = pa.array([True] * len(tb), type=pa.bool_()) else: - bitmask = tag_tb["bitmask"] - tb = tb.append_column(tag["tag_name"], bitmask) - tbs.append(tb) + bitmask = pa.array([False] * len(tb), type=pa.bool_()) + else: + bitmask = tag_tb["bitmask"] + tb = tb.append_column(tag["tag_name"], bitmask) + tbs.append(tb) return pa.concat_tables(tbs).to_pandas() - def get_tags(self) -> List[Dict[str, str]]: + def get_tags(self) -> Dict[str, List[str]]: ''' Retrieves back all tags made in the web browser for a specific map. Each tag is a dictionary containing tag_name, tag_id, and metadata. @@ -706,25 +708,24 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): all_quads = list(self.projection._tiles_in_order(coords_only=True)) ordered_tag_paths = [] for quad in tqdm(all_quads): - if isinstance(quad, Tuple): - quad_str = os.path.join(*[str(q) for q in quad]) - filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" - path = self.projection.tile_destination / Path(filename) - download_attempt = 0 - download_success = False - while download_attempt < 3 and not download_success: - download_attempt += 1 - if not path.exists() or overwrite: - download_feather(root_url + filename, path, headers=self.dataset.header) - try: - ipc.open_file(path).schema - download_success = True - except pa.ArrowInvalid: - path.unlink(missing_ok=True) - - if not download_success: - raise Exception(f"Failed to download tag {tag_name}.") - ordered_tag_paths.append(path) + quad_str = os.path.join(*[str(q) for q in quad]) + filename = quad_str + "." 
+ f"_tag.{tag_definition_id}" + ".feather" + path = self.projection.tile_destination / Path(filename) + download_attempt = 0 + download_success = False + while download_attempt < 3 and not download_success: + download_attempt += 1 + if not path.exists() or overwrite: + download_feather(root_url + filename, path, headers=self.dataset.header) + try: + ipc.open_file(path).schema + download_success = True + except pa.ArrowInvalid: + path.unlink(missing_ok=True) + + if not download_success: + raise Exception(f"Failed to download tag {tag_name}.") + ordered_tag_paths.append(path) return ordered_tag_paths def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): @@ -738,23 +739,22 @@ def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): # NOTE: This currently only gets triggered on `df` property all_quads = list(self.projection._tiles_in_order(coords_only=True)) for quad in tqdm(all_quads): - if isinstance(quad, Tuple): - quad_str = os.path.join(*[str(q) for q in quad]) - tile = self.projection.tile_destination / Path(quad_str) - tile_dir = tile.parent - if tile_dir.exists(): - tagged_files = tile_dir.glob('*_tag*') - for file in tagged_files: - tag_definition_id = file.name.split(".")[-2] - if tag_definition_id in tag_definition_ids: - try: - file.unlink() - except PermissionError: - print("Permission denied: unable to delete outdated tag file. Skipping") - return - except Exception as e: - print(f"Exception occurred when trying to delete outdated tag file: {e}. Skipping") - return + quad_str = os.path.join(*[str(q) for q in quad]) + tile = self.projection.tile_destination / Path(quad_str) + tile_dir = tile.parent + if tile_dir.exists(): + tagged_files = tile_dir.glob('*_tag*') + for file in tagged_files: + tag_definition_id = file.name.split(".")[-2] + if tag_definition_id in tag_definition_ids: + try: + file.unlink() + except PermissionError: + print("Permission denied: unable to delete outdated tag file. Skipping") + return + except Exception as e: + print(f"Exception occurred when trying to delete outdated tag file: {e}. Skipping") + return def add(self, ids: List[str], tags: List[str]): # ''' @@ -792,34 +792,35 @@ class AtlasMapData: you uploaded with your project. 
""" - def __init__(self, projection: "AtlasProjection", fields=None): # type: ignore + def __init__(self, projection: "AtlasProjection", fields=None): self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field + self._tb = None self.fields = fields try: # Run fetch_tiles first to guarantee existence of quad feather files self._basic_data: pa.Table = self.projection._fetch_tiles() sidecars = self._download_data(fields=fields) - self._tb = self._read_prefetched_tiles_with_sidecars(sidecars) + self._read_prefetched_tiles_with_sidecars(sidecars) - except pa.lib.ArrowInvalid as e: # type: ignore + except pa.lib.ArrowInvalid as e: raise ValueError("Failed to fetch tiles for this map") - def _read_prefetched_tiles_with_sidecars(self, additional_sidecars): + def _read_prefetched_tiles_with_sidecars(self, additional_sidecars=None): tbs = [] - root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather")) # type: ignore + root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather")) try: small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()]) except KeyError: small_sidecars = set([]) for path in self.projection._tiles_in_order(): - tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"]) # type: ignore + tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"]) for col in tb.column_names: if col[0] == "_": tb = tb.drop([col]) for sidecar_file in small_sidecars: - carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) # type: ignore + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) for big_sidecar in additional_sidecars: @@ -828,7 +829,7 @@ def _read_prefetched_tiles_with_sidecars(self, additional_sidecars): if big_sidecar != 'datum_id' else big_sidecar ) - carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True) # type: ignore + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) @@ -873,7 +874,7 @@ def _download_data(self, fields=None): return sidecars @property - def df(self) -> pd.DataFrame: + def df(self) -> pandas.DataFrame: """ A pandas DataFrame associating each datapoint on your map to their metadata. Converting to pandas DataFrame may materialize a large amount of data into memory. diff --git a/nomic/dataset.py b/nomic/dataset.py index 6bbd9a01..4c3939af 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -34,13 +34,7 @@ NomicTopicOptions, convert_pyarrow_schema_for_atlas, ) -from .data_operations import ( - AtlasMapData, - AtlasMapDuplicates, - AtlasMapEmbeddings, - AtlasMapTags, - AtlasMapTopics, -) +from .data_operations import AtlasMapData, AtlasMapDuplicates, AtlasMapEmbeddings, AtlasMapTags, AtlasMapTopics from .settings import * from .utils import assert_valid_project_id, get_object_size_in_bytes @@ -113,9 +107,7 @@ def _get_current_user(self): ) response = validate_api_http_response(response) if not response.status_code == 200: - raise ValueError( - "Your authorization token is no longer valid. Run `nomic login` to obtain a new one." - ) + raise ValueError("Your authorization token is no longer valid. 
Run `nomic login` to obtain a new one.") return response.json() @@ -130,9 +122,7 @@ def _validate_map_data_inputs(self, colorable_fields, id_field, data_sample): for field in colorable_fields: if field not in data_sample: - raise Exception( - f"Cannot color by field `{field}` as it is not present in the metadata." - ) + raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.") def _get_current_users_main_organization(self): """ @@ -149,10 +139,7 @@ def _get_current_users_main_organization(self): return organization for organization in user["organizations"]: - if ( - organization["user_id"] == user["sub"] - and organization["access_role"] == "OWNER" - ): + if organization["user_id"] == user["sub"] and organization["access_role"] == "OWNER": return organization return {} @@ -182,9 +169,7 @@ def _get_project_by_id(self, project_id: str): ) if response.status_code != 200: - raise Exception( - f"Could not access dataset with id {project_id}: {response.text}" - ) + raise Exception(f"Could not access dataset with id {project_id}: {response.text}") return response.json() @@ -270,9 +255,7 @@ def _get_index_job(self, job_id: str): return response.json() - def _validate_and_correct_arrow_upload( - self, data: pa.Table, project: "AtlasDataset" - ) -> pa.Table: + def _validate_and_correct_arrow_upload(self, data: pa.Table, project: "AtlasDataset") -> pa.Table: """ Private method. validates upload data against the dataset arrow schema, and associated other checks. @@ -304,8 +287,7 @@ def _validate_and_correct_arrow_upload( for col in data.column_names: if col.lower() in seen: raise ValueError( - f"Two different fields have the same lowercased name, `{col}`" - ": you must use unique column names." + f"Two different fields have the same lowercased name, `{col}`" ": you must use unique column names." ) seen.add(col.lower()) @@ -337,41 +319,27 @@ def _validate_and_correct_arrow_upload( logger.warning( f"Replacing {data[field.name].null_count} null values for field {field.name} with string 'null'. This behavior will change in a future version." ) - reformatted[field.name] = pc.fill_null( - reformatted[field.name], "null" - ) + reformatted[field.name] = pc.fill_null(reformatted[field.name], "null") if pc.any(pc.equal(pc.binary_length(reformatted[field.name]), 0)): - mask = pc.equal( - pc.binary_length(reformatted[field.name]), 0 - ).combine_chunks() + mask = pc.equal(pc.binary_length(reformatted[field.name]), 0).combine_chunks() assert pa.types.is_boolean(mask.type) - reformatted[field.name] = pc.replace_with_mask( - reformatted[field.name], mask, "null" - ) + reformatted[field.name] = pc.replace_with_mask(reformatted[field.name], mask, "null") for field in data.schema: if not field.name in reformatted: if field.name == "_embeddings": reformatted["_embeddings"] = data["_embeddings"] else: - logger.warning( - f"Field {field.name} present in data, but not found in table schema. Ignoring." - ) + logger.warning(f"Field {field.name} present in data, but not found in table schema. Ignoring.") data = pa.Table.from_pydict(reformatted, schema=project.schema) if project.meta["insert_update_delete_lock"]: - raise Exception( - "Project is currently indexing and cannot ingest new datums. Try again later." - ) + raise Exception("Project is currently indexing and cannot ingest new datums. Try again later.") # The following two conditions should never occur given the above, but just in case... 
- assert ( - project.id_field in data.column_names - ), f"Upload does not contain your specified id_field" + assert project.id_field in data.column_names, f"Upload does not contain your specified id_field" if not pa.types.is_string(data[project.id_field].type): - logger.warning( - f"id_field is not a string. Converting to string from {data[project.id_field].type}" - ) + logger.warning(f"id_field is not a string. Converting to string from {data[project.id_field].type}") data = data.drop([project.id_field]).append_column( project.id_field, data[project.id_field].cast(pa.string()) ) @@ -382,17 +350,15 @@ def _validate_and_correct_arrow_upload( continue raise ValueError("Metadata fields cannot start with _") if pc.max(pc.utf8_length(data[project.id_field])).as_py() > 36: - first_match = data.filter( - pc.greater(pc.utf8_length(data[project.id_field]), 36) - ).to_pylist()[0][project.id_field] + first_match = data.filter(pc.greater(pc.utf8_length(data[project.id_field]), 36)).to_pylist()[0][ + project.id_field + ] raise ValueError( f"The id_field {first_match} is greater than 36 characters. Atlas does not support id_fields longer than 36 characters." ) return data - def _get_organization( - self, organization_slug=None, organization_id=None - ) -> Tuple[str, str]: + def _get_organization(self, organization_slug=None, organization_id=None) -> Tuple[str, str]: """ Gets an organization by either its name or id. @@ -406,16 +372,12 @@ def _get_organization( """ if organization_slug is None: - if ( - organization_id is None - ): # default to current users organization (the one with their name) + if organization_id is None: # default to current users organization (the one with their name) organization = self._get_current_users_main_organization() organization_slug = organization["slug"] organization_id = organization["organization_id"] else: - raise NotImplementedError( - "Getting organization by a specific ID is not yet implemented." - ) + raise NotImplementedError("Getting organization by a specific ID is not yet implemented.") else: try: @@ -438,9 +400,7 @@ class AtlasIndex: the points in the index that you can browse online. """ - def __init__( - self, atlas_index_id, name, indexed_field, projections, dataset: AtlasDataset - ): + def __init__(self, atlas_index_id, name, indexed_field, projections, dataset: AtlasDataset): """Initializes an Atlas index. Atlas indices organize data and store views of the data as maps.""" self.id = atlas_index_id self.name = name @@ -508,8 +468,7 @@ def map_link(self): @property def _status(self): response = requests.get( - self.dataset.atlas_api_path - + f"/v1/project/index/job/progress/{self.atlas_index_id}", + self.dataset.atlas_api_path + f"/v1/project/index/job/progress/{self.atlas_index_id}", headers=self.dataset.header, ) if response.status_code != 200: @@ -586,9 +545,7 @@ def _repr_html_(self): def duplicates(self): """Duplicate detection state""" if self.dataset.is_locked: - raise Exception( - "Dataset is locked! Please wait until the dataset is unlocked to access duplicates." - ) + raise Exception("Dataset is locked! 
Please wait until the dataset is unlocked to access duplicates.") if self._duplicates is None: self._duplicates = AtlasMapDuplicates(self) return self._duplicates @@ -646,8 +603,7 @@ def schema(self): ) if self._schema is None: response = requests.get( - self.dataset.atlas_api_path - + f"/v1/project/projection/{self.projection_id}/schema", + self.dataset.atlas_api_path + f"/v1/project/projection/{self.projection_id}/schema", headers=self.dataset.header, ) if response.status_code != 200: @@ -680,24 +636,16 @@ def _fetch_tiles(self, overwrite: bool = True): return self._tile_data self._download_large_feather(overwrite=overwrite) tbs = [] - root = feather.read_table( - self.tile_destination / "0/0/0.feather", memory_map=True - ) + root = feather.read_table(self.tile_destination / "0/0/0.feather", memory_map=True) try: - sidecars = set( - [v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()] - ) + sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()]) except KeyError: sidecars = set([]) - sidecars |= set( - sidecar_name for (_, sidecar_name) in self._registered_sidecars() - ) + sidecars |= set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) for path in self._tiles_in_order(): tb = pa.feather.read_table(path, memory_map=True) for sidecar_file in sidecars: - carfile = pa.feather.read_table( - path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True - ) + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) @@ -728,9 +676,7 @@ def children(z, x, y): # Pop off the front, extend the back (breadth first traversal) while len(paths) > 0: z, x, y = paths.pop(0) - path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix( - ".feather" - ) + path = Path(self.tile_destination, str(z), str(x), str(y)).with_suffix(".feather") if path.exists(): if coords_only: yield (z, x, y) @@ -742,9 +688,7 @@ def children(z, x, y): def tile_destination(self): return Path("~/.nomic/cache", self.id).expanduser() - def _download_large_feather( - self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True - ): + def _download_large_feather(self, dest: Optional[Union[str, Path]] = None, overwrite: bool = True): """ Downloads the feather tree. Args: @@ -759,9 +703,7 @@ def _download_large_feather( root = f"{self.dataset.atlas_api_path}/v1/project/{self.dataset.id}/index/projection/{self.id}/quadtree/" all_quads = [] sidecars = None - registered_sidecars = set( - sidecar_name for (_, sidecar_name) in self._registered_sidecars() - ) + registered_sidecars = set(sidecar_name for (_, sidecar_name) in self._registered_sidecars()) while len(quads) > 0: rawquad = quads.pop(0) quad = rawquad + ".feather" @@ -791,9 +733,7 @@ def _download_large_feather( if sidecars is None and b"sidecars" in schema.metadata: # Grab just the filenames - sidecars = set( - [v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()] - ) + sidecars = set([v for k, v in json.loads(schema.metadata.get(b"sidecars")).items()]) elif sidecars is None: sidecars = set() if not "." 
in rawquad: @@ -847,9 +787,7 @@ class AtlasDataStream(AtlasClass): def __init__(self, name: Optional[str] = "contrastors"): super().__init__() if name != "contrastors": - raise NotImplementedError( - "Only contrastors datastream is currently supported" - ) + raise NotImplementedError("Only contrastors datastream is currently supported") self.name = name # TODO: add support for other datastreams @@ -888,9 +826,7 @@ def __init__( * **is_public** - Should this dataset be publicly accessible for viewing (read only). If False, only members of your Nomic organization can view. * **dataset_id** - An alternative way to load a dataset is by passing the dataset_id directly. This only works if a dataset exists. """ - assert ( - identifier is not None or dataset_id is not None - ), "You must pass a dataset identifier" + assert identifier is not None or dataset_id is not None, "You must pass a dataset identifier" super().__init__() @@ -914,9 +850,7 @@ def __init__( dataset_id = dataset["id"] if dataset_id is None: # if there is no existing project, make a new one. - if ( - unique_id_field is None - ): # if not all parameters are specified, we weren't trying to make a project + if unique_id_field is None: # if not all parameters are specified, we weren't trying to make a project raise ValueError(f"Dataset `{identifier}` does not exist.") # if modality is None: @@ -942,9 +876,7 @@ def delete(self): organization = self._get_current_users_main_organization() organization_slug = organization["slug"] - logger.info( - f"Deleting dataset `{self.slug}` from organization `{organization_slug}`" - ) + logger.info(f"Deleting dataset `{self.slug}` from organization `{organization_slug}`") self._delete_project_by_id(project_id=self.id) @@ -1065,9 +997,7 @@ def id_field(self) -> str: @property def created_timestamp(self) -> datetime: - return datetime.strptime( - self.meta["created_timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z" - ) + return datetime.strptime(self.meta["created_timestamp"], "%Y-%m-%dT%H:%M:%S.%f%z") @property def total_datums(self) -> int: @@ -1111,9 +1041,7 @@ def schema(self) -> Optional[pa.Schema]: if self._schema is not None: return self._schema if "schema" in self.meta and self.meta["schema"] is not None: - self._schema: pa.Schema = ipc.read_schema( - io.BytesIO(base64.b64decode(self.meta["schema"])) - ) + self._schema: pa.Schema = ipc.read_schema(io.BytesIO(base64.b64decode(self.meta["schema"]))) return self._schema return None @@ -1140,9 +1068,7 @@ def wait_for_dataset_lock(self): has_logged = True time.sleep(5) - def get_map( - self, name: str = None, atlas_index_id: str = None, projection_id: str = None - ) -> AtlasProjection: + def get_map(self, name: str = None, atlas_index_id: str = None, projection_id: str = None) -> AtlasProjection: """ Retrieves a map. 
@@ -1161,22 +1087,16 @@ def get_map( for index in indices: if index.id == atlas_index_id: if len(index.projections) == 0: - raise ValueError( - f"No map found under index with atlas_index_id='{atlas_index_id}'" - ) + raise ValueError(f"No map found under index with atlas_index_id='{atlas_index_id}'") return index.projections[0] - raise ValueError( - f"Could not find a map with atlas_index_id='{atlas_index_id}'" - ) + raise ValueError(f"Could not find a map with atlas_index_id='{atlas_index_id}'") if projection_id is not None: for index in indices: for projection in index.projections: if projection.id == projection_id: return projection - raise ValueError( - f"Could not find a map with projection_id='{atlas_index_id}'" - ) + raise ValueError(f"Could not find a map with projection_id='{atlas_index_id}'") if len(indices) == 0: raise ValueError("You have no maps built in your project") @@ -1259,9 +1179,7 @@ def create_index( colorable_fields = [] for field in self.dataset_fields: - if field not in [self.id_field, indexed_field] and not field.startswith( - "_" - ): + if field not in [self.id_field, indexed_field] and not field.startswith("_"): colorable_fields.append(field) if self.modality == "embedding": @@ -1278,9 +1196,7 @@ def create_index( "colorable_fields": colorable_fields, "model_hyperparameters": None, "nearest_neighbor_index": "HNSWIndex", - "nearest_neighbor_index_hyperparameters": json.dumps( - {"space": "l2", "ef_construction": 100, "M": 16} - ), + "nearest_neighbor_index_hyperparameters": json.dumps({"space": "l2", "ef_construction": 100, "M": 16}), "projection": "NomicProject", "projection_hyperparameters": json.dumps( { @@ -1323,14 +1239,10 @@ def create_index( ) if indexed_field is None: - raise Exception( - "You did not specify a field to index. Specify an 'indexed_field'." - ) + raise Exception("You did not specify a field to index. Specify an 'indexed_field'.") if indexed_field not in self.dataset_fields: - raise Exception( - f"Indexing on {indexed_field} not allowed. Valid options are: {self.dataset_fields}" - ) + raise Exception(f"Indexing on {indexed_field} not allowed. 
Valid options are: {self.dataset_fields}") build_template = { "project_id": self.id, @@ -1349,9 +1261,7 @@ def create_index( } ), "nearest_neighbor_index": "HNSWIndex", - "nearest_neighbor_index_hyperparameters": json.dumps( - {"space": "l2", "ef_construction": 100, "M": 16} - ), + "nearest_neighbor_index_hyperparameters": json.dumps({"space": "l2", "ef_construction": 100, "M": 16}), "projection": "NomicProject", "projection_hyperparameters": json.dumps( { @@ -1385,9 +1295,7 @@ def create_index( json=build_template, ) if response.status_code != 200: - logger.info( - "Create dataset failed with code: {}".format(response.status_code) - ) + logger.info("Create dataset failed with code: {}".format(response.status_code)) logger.info("Additional info: {}".format(response.text)) raise Exception(response.json()["detail"]) @@ -1412,9 +1320,7 @@ def create_index( if atlas_projection is None: logger.warning("Could not find a map being built for this dataset.") - logger.info( - f"Created map `{projection.name}` in dataset `{self.identifier}`: {projection.map_link}" - ) + logger.info(f"Created map `{projection.name}` in dataset `{self.identifier}`: {projection.map_link}") return projection def __repr__(self): @@ -1448,13 +1354,7 @@ def _repr_html_(self): return html def __str__(self): - return "\n".join( - [ - str(projection) - for index in self.indices - for projection in index.projections - ] - ) + return "\n".join([str(projection) for index in self.indices for projection in index.projections]) def get_data(self, ids: List[str]) -> List[Dict]: """ @@ -1471,9 +1371,7 @@ def get_data(self, ids: List[str]) -> List[Dict]: if not isinstance(ids, list): raise ValueError("You must specify a list of ids when getting data.") if isinstance(ids[0], list): - raise ValueError( - "You must specify a list of ids when getting data, not a nested list." - ) + raise ValueError("You must specify a list of ids when getting data, not a nested list.") response = requests.post( self.atlas_api_path + "/v1/project/data/get", headers=self.header, @@ -1523,9 +1421,7 @@ def add_data( embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints. pbar: (Optional). A tqdm progress bar to update. """ - if embeddings is not None or ( - isinstance(data, pa.Table) and "_embeddings" in data.column_names - ): + if embeddings is not None or (isinstance(data, pa.Table) and "_embeddings" in data.column_names): self._add_embeddings(data=data, embeddings=embeddings, pbar=pbar) else: self._add_text(data=data, pbar=pbar) @@ -1541,9 +1437,7 @@ def _add_text(self, data=Union[DataFrame, List[Dict], pa.Table], pbar=None): elif isinstance(data, list): data = pa.Table.from_pylist(data) elif not isinstance(data, pa.Table): - raise ValueError( - "Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table." - ) + raise ValueError("Data must be a pandas DataFrame, list of dictionaries, or a pyarrow Table.") self._add_data(data, pbar=pbar) def _add_embeddings( @@ -1567,9 +1461,7 @@ def _add_embeddings( """ assert type(embeddings) == np.ndarray, "Embeddings must be a NumPy array." assert len(embeddings.shape) == 2, "Embeddings must be a 2D NumPy array." - assert ( - len(data) == embeddings.shape[0] - ), "Data and embeddings must have the same number of rows." + assert len(data) == embeddings.shape[0], "Data and embeddings must have the same number of rows." assert len(data) > 0, "Data must have at least one row." 
tb: pa.Table @@ -1595,9 +1487,7 @@ def _add_embeddings( assert not np.isnan(embeddings).any(), "Embeddings must not contain NaN values." assert not np.isinf(embeddings).any(), "Embeddings must not contain Inf values." - pyarrow_embeddings = pa.FixedSizeListArray.from_arrays( - embeddings.reshape((-1)), embeddings.shape[1] - ) + pyarrow_embeddings = pa.FixedSizeListArray.from_arrays(embeddings.reshape((-1)), embeddings.shape[1]) data_with_embeddings = tb.append_column("_embeddings", pyarrow_embeddings) @@ -1648,9 +1538,7 @@ def send_request(i): data_shard = data.slice(i, shard_size) with io.BytesIO() as buffer: data_shard = data_shard.replace_schema_metadata({"project_id": self.id}) - feather.write_feather( - data_shard, buffer, compression="zstd", compression_level=6 - ) + feather.write_feather(data_shard, buffer, compression="zstd", compression_level=6) buffer.seek(0) response = requests.post( @@ -1669,26 +1557,18 @@ def send_request(i): succeeded = 0 errors_504 = 0 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = { - executor.submit(send_request, i): i - for i in range(0, len(data), shard_size) - } + futures = {executor.submit(send_request, i): i for i in range(0, len(data), shard_size)} while futures: # check for status of the futures which are currently working - done, not_done = concurrent.futures.wait( - futures, return_when=concurrent.futures.FIRST_COMPLETED - ) + done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED) # process any completed futures for future in done: response = future.result() if response.status_code != 200: try: logger.error(f"Shard upload failed: {response.text}") - if ( - "more datums exceeds your organization limit" - in response.json() - ): + if "more datums exceeds your organization limit" in response.json(): return False if "Project transaction lock is held" in response.json(): raise Exception( @@ -1699,9 +1579,7 @@ def send_request(i): except (requests.JSONDecodeError, json.decoder.JSONDecodeError): if response.status_code == 413: # Possibly split in two and retry? - logger.error( - "Shard upload failed: you are sending meta-data that is too large." - ) + logger.error("Shard upload failed: you are sending meta-data that is too large.") pbar.update(1) response.close() failed += shard_size @@ -1711,25 +1589,16 @@ def send_request(i): logger.debug( f"{self.identifier}: Connection failed for records {start_point}-{start_point + shard_size}, retrying." ) - failure_fraction = errors_504 / ( - failed + succeeded + errors_504 - ) - if ( - failure_fraction > 0.5 - and errors_504 > shard_size * 3 - ): + failure_fraction = errors_504 / (failed + succeeded + errors_504) + if failure_fraction > 0.5 and errors_504 > shard_size * 3: raise RuntimeError( f"{self.identifier}: Atlas is under high load and cannot ingest datums at this time. Please try again later." 
) - new_submission = executor.submit( - send_request, start_point - ) + new_submission = executor.submit(send_request, start_point) futures[new_submission] = start_point response.close() else: - logger.error( - f"{self.identifier}: Shard upload failed: {response}" - ) + logger.error(f"{self.identifier}: Shard upload failed: {response}") failed += shard_size pbar.update(1) response.close() @@ -1781,12 +1650,12 @@ def update_maps( raise ValueError(msg) if embeddings is not None and len(data) != embeddings.shape[0]: - msg = "Expected data and embeddings to be the same length but found lengths {} and {} respectively.".format() + msg = ( + "Expected data and embeddings to be the same length but found lengths {} and {} respectively.".format() + ) raise ValueError(msg) - shard_size = ( - 2000 # TODO someone removed shard size from params and didn't update - ) + shard_size = 2000 # TODO someone removed shard size from params and didn't update # Add new data logger.info("Uploading data to Nomic's neural database Atlas.") with tqdm(total=len(data) // shard_size) as pbar: diff --git a/nomic/embed.py b/nomic/embed.py index 36a2a2ce..2c177ae9 100644 --- a/nomic/embed.py +++ b/nomic/embed.py @@ -6,7 +6,7 @@ import os import time from io import BytesIO -from typing import Any, List, Literal, Optional, Sequence, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union, overload import PIL import PIL.Image @@ -15,14 +15,13 @@ from .dataset import AtlasClass from .settings import * -embed4all_installed = True - try: from gpt4all import CancellationError, Embed4All except ImportError: - embed4all_installed = False + if not TYPE_CHECKING: + Embed4All = None -atlas_class: Optional[AtlasClass] = None +atlas_class = None MAX_TEXT_REQUEST_SIZE = 50 MAX_IMAGE_REQUEST_SIZE = 512 @@ -49,33 +48,26 @@ def request_backoff( max_retries=5, backoff_if=is_backoff_status_code, ): - response = callable() for attempt in range(max_retries + 1): + response = callable() if attempt == max_retries: return response if backoff_if(response.status_code): delay = init_backoff * (ratio**attempt) logging.info(f"server error, backing off for {int(delay)}s") time.sleep(delay) - response = callable() else: - break - return response + return response def text_api_request( - texts: List[str], model: str, task_type: str, dimensionality: Optional[int] = None, long_text_mode: str = "truncate" + texts: List[str], model: str, task_type: str, dimensionality: int = None, long_text_mode: str = "truncate" ): global atlas_class - - assert atlas_class is not None - text_api_url = atlas_class.atlas_api_path + "/v1/embedding/text" - text_api_header = atlas_class.header - response = request_backoff( lambda: requests.post( - text_api_url, - headers=text_api_header, + atlas_class.atlas_api_path + "/v1/embedding/text", + headers=atlas_class.header, json={ "texts": texts, "model": model, @@ -102,8 +94,6 @@ def text( long_text_mode: str = ..., inference_mode: Literal["remote"] = ..., ) -> dict[str, Any]: ... - - @overload def text( texts: list[str], @@ -116,8 +106,6 @@ def text( device: str | None = ..., **kwargs: Any, ) -> dict[str, Any]: ... 
- - @overload def text( texts: list[str], @@ -181,7 +169,7 @@ def text( raise TypeError(f"device argument cannot be used with inference_mode='remote'") if kwargs: raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}") - elif embed4all_installed is None: + elif Embed4All is None: raise RuntimeError( f"The 'gpt4all' package is required for local inference. Suggestion: `pip install \"nomic[local]\"`", ) @@ -198,7 +186,7 @@ def text( device=device, **kwargs, ) - except CancellationError: # type: ignore + except CancellationError: pass # dynamic mode chose to use Atlas, fall through return _text_atlas(texts, model, task_type, dimensionality, long_text_mode) @@ -249,15 +237,15 @@ def _text_atlas( return combined -_embed4all: Optional[Embed4All] = None -_embed4all_kwargs: Optional[dict[str, Any]] = None +_embed4all: Embed4All | None = None +_embed4all_kwargs: dict[str, Any] | None = None def _text_embed4all( texts: list[str], model: str, task_type: str, - dimensionality: Optional[int], + dimensionality: int | None, long_text_mode: str, dynamic_mode: bool, **kwargs: Any, @@ -276,7 +264,7 @@ def _text_embed4all( if _embed4all is None or _embed4all.gpt4all.config["filename"] != g4a_model or _embed4all_kwargs != kwargs: if _embed4all is not None: _embed4all.close() - _embed4all = Embed4All(g4a_model, **kwargs) # type: ignore + _embed4all = Embed4All(g4a_model, **kwargs) _embed4all_kwargs = kwargs def cancel_cb(batch_sizes: list[int], backend: str) -> bool: @@ -309,15 +297,10 @@ def free_embedding_model() -> None: def image_api_request(images: List[Tuple[str, bytes]], model: str = 'nomic-embed-vision-v1'): global atlas_class - - assert atlas_class is not None - atlas_url = atlas_class.atlas_api_path - atlas_header = atlas_class.header - response = request_backoff( lambda: requests.post( - atlas_url + "/v1/embedding/image", - headers=atlas_header, + atlas_class.atlas_api_path + "/v1/embedding/image", + headers=atlas_class.header, data={"model": model}, files=images, ) @@ -340,7 +323,7 @@ def resize_pil(img): return img -def images(images: Sequence[Union[str, PIL.Image.Image]], model: str = 'nomic-embed-vision-v1'): +def images(images: Iterable[Union[str, PIL.Image.Image]], model: str = 'nomic-embed-vision-v1'): """ Generates embeddings for the given images. 
@@ -368,7 +351,6 @@ def images(images: Sequence[Union[str, PIL.Image.Image]], model: str = 'nomic-em img = resize_pil(PIL.Image.open(image)) buffered = BytesIO() img.save(buffered, format="JPEG") - image_batch.append(("images", buffered.getvalue())) elif isinstance(image, PIL.Image.Image): From 73cad6e38b27aba98654baca860482f5d9510e9b Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Thu, 25 Apr 2024 15:04:08 -0400 Subject: [PATCH 3/5] linting --- nomic/data_operations.py | 314 ++++++++++++++++++++++++++------------- 1 file changed, 213 insertions(+), 101 deletions(-) diff --git a/nomic/data_operations.py b/nomic/data_operations.py index 75ac4ffa..60120337 100644 --- a/nomic/data_operations.py +++ b/nomic/data_operations.py @@ -37,10 +37,18 @@ def __init__(self, projection: "AtlasProjection"): self.id_field = self.projection.dataset.id_field try: duplicate_fields = [ - field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field + field + for field in projection._fetch_tiles().column_names + if "_duplicate_class" in field + ] + cluster_fields = [ + field + for field in projection._fetch_tiles().column_names + if "_cluster" in field ] - cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field] - assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map." + assert ( + len(duplicate_fields) > 0 + ), "Duplicate detection has not yet been run on this map." self.duplicate_field = duplicate_fields[0] self.cluster_field = cluster_fields[0] self._tb: pa.Table = projection._fetch_tiles().select( @@ -50,7 +58,9 @@ def __init__(self, projection: "AtlasProjection"): raise ValueError("Duplicate detection has not yet been run on this map.") self.duplicate_field = self.duplicate_field.lstrip("_") self.cluster_field = self.cluster_field.lstrip("_") - self._tb = self._tb.rename_columns([self.id_field, self.duplicate_field, self.cluster_field]) + self._tb = self._tb.rename_columns( + [self.id_field, self.duplicate_field, self.cluster_field] + ) @property def df(self) -> pd.DataFrame: @@ -74,13 +84,17 @@ def deletion_candidates(self) -> List[str]: Returns: The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated. 
""" - dupes = self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')) + dupes = self.tb[self.id_field].filter( + pc.equal(self.tb[self.duplicate_field], "deletion candidate") + ) return dupes.to_pylist() def __repr__(self) -> str: repr = f"===Atlas Duplicates for ({self.projection})\n" duplicate_count = len( - self.tb[self.id_field].filter(pc.equal(self.tb[self.duplicate_field], 'deletion candidate')) + self.tb[self.id_field].filter( + pc.equal(self.tb[self.duplicate_field], "deletion candidate") + ) ) cluster_count = len(self.tb[self.cluster_field].value_counts()) repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n" @@ -101,29 +115,38 @@ def __init__(self, projection: "AtlasProjection"): try: self._tb: pa.Table = projection._fetch_tiles() - topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")] + topic_fields = [ + column + for column in self._tb.column_names + if column.startswith("_topic_depth_") + ] self.depth = len(topic_fields) # If using topic ids, fetch topic labels - if 'int' in topic_fields[0]: + if "int" in topic_fields[0]: new_topic_fields = [] metadata = self.metadata label_df = metadata[["topic_id", "depth", "topic_short_description"]] for d in range(1, self.depth + 1): column = f"_topic_depth_{d}_int" - topic_ids_to_label = self._tb[column].to_pandas().rename('topic_id') + topic_ids_to_label = self._tb[column].to_pandas().rename("topic_id") topic_ids_to_label = label_df[label_df["depth"] == d].merge( - topic_ids_to_label, on='topic_id', how='right' + topic_ids_to_label, on="topic_id", how="right" ) new_column = f"_topic_depth_{d}" self._tb = self._tb.append_column( - new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"]) + new_column, + pa.Array.from_pandas( + topic_ids_to_label["topic_short_description"] + ), ) new_topic_fields.append(new_column) topic_fields = new_topic_fields - renamed_fields = [f'topic_depth_{i}' for i in range(1, self.depth + 1)] - self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns([self.id_field] + renamed_fields) + renamed_fields = [f"topic_depth_{i}" for i in range(1, self.depth + 1)] + self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns( + [self.id_field] + renamed_fields + ) except pa.lib.ArrowInvalid as e: raise ValueError("Topic modeling has not yet been run on this map.") @@ -159,14 +182,16 @@ def metadata(self) -> pandas.DataFrame: response = requests.get( self.projection.dataset.atlas_api_path + "/v1/project/{}/index/projection/{}".format( - self.projection.dataset.meta['id'], self.projection.projection_id + self.projection.dataset.meta["id"], self.projection.projection_id ), headers=self.projection.dataset.header, ) - topics = json.loads(response.text)['topic_models'][0]['features'] - topic_data = [e['properties'] for e in topics] + topics = json.loads(response.text)["topic_models"][0]["features"] + topic_data = [e["properties"] for e in topics] topic_data = pd.DataFrame(topic_data) - column_list = [(f"_topic_depth_{i}", f"topic_depth_{i}") for i in range(1, self.depth + 1)] + column_list = [ + (f"_topic_depth_{i}", f"topic_depth_{i}") for i in range(1, self.depth + 1) + ] column_list.append(("topic", "topic_id")) topic_data = topic_data.rename(columns=dict(column_list)) self._metadata = topic_data @@ -196,8 +221,13 @@ def hierarchy(self) -> Dict: # list of subtopics for the topic at the previous depth for topic_index in range(len(topics) - 1): # depth is index + 1 - if 
topics[topic_index + 1] not in topic_hierarchy[(topics[topic_index], topic_index + 1)]: - topic_hierarchy[(topics[topic_index], topic_index + 1)].append(topics[topic_index + 1]) + if ( + topics[topic_index + 1] + not in topic_hierarchy[(topics[topic_index], topic_index + 1)] + ): + topic_hierarchy[(topics[topic_index], topic_index + 1)].append( + topics[topic_index + 1] + ) self._hierarchy = dict(topic_hierarchy) return self._hierarchy @@ -221,7 +251,9 @@ def group_by_topic(self, topic_depth: int = 1) -> List[Dict]: datum_id_col = self.dataset.meta["unique_id_field"] df = self.df - topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict() + topic_datum_dict = ( + df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict() + ) topic_df = self.metadata hierarchy = self.hierarchy result = [] @@ -238,18 +270,20 @@ def group_by_topic(self, topic_depth: int = 1) -> List[Dict]: if (topic_label, topic_depth) in hierarchy: subtopics = hierarchy[(topic_label, topic_depth)] result_dict["subtopics"] = subtopics - result_dict["subtopic_ids"] = topic_df[topic_df["topic_short_description"].isin(subtopics)][ - "topic_id" - ].tolist() + result_dict["subtopic_ids"] = topic_df[ + topic_df["topic_short_description"].isin(subtopics) + ]["topic_id"].tolist() result_dict["topic_id"] = topic_metadata["topic_id"].item() result_dict["topic_short_description"] = topic_label - result_dict["topic_long_description"] = topic_metadata["topic_description"].item() + result_dict["topic_long_description"] = topic_metadata[ + "topic_description" + ].item() result_dict["datum_ids"] = datum_ids result.append(result_dict) return result def get_topic_density(self, time_field: str, start: datetime, end: datetime): - ''' + """ Computes the density/frequency of topics in a given interval of a timestamp field. Useful for answering questions such as: @@ -263,10 +297,12 @@ def get_topic_density(self, time_field: str, start: datetime, end: datetime): Returns: A list of `{topic, count}` dictionaries, sorted from largest count to smallest count. - ''' + """ data = AtlasMapData(self.projection, fields=[time_field]) time_data = data._tb.select([self.id_field, time_field]) - merged_tb = self._tb.join(time_data, self.id_field, join_type="inner").combine_chunks() + merged_tb = self._tb.join( + time_data, self.id_field, join_type="inner" + ).combine_chunks() del time_data # free up memory @@ -274,17 +310,23 @@ def get_topic_density(self, time_field: str, start: datetime, end: datetime): merged_tb = merged_tb.filter(expr) topic_densities = {} for depth in range(1, self.depth + 1): - topic_column = f'topic_depth_{depth}' - topic_counts = merged_tb.group_by(topic_column).aggregate([(self.id_field, "count")]).to_pandas() + topic_column = f"topic_depth_{depth}" + topic_counts = ( + merged_tb.group_by(topic_column) + .aggregate([(self.id_field, "count")]) + .to_pandas() + ) for _, row in topic_counts.iterrows(): topic = row[topic_column] if topic not in topic_densities: topic_densities[topic] = 0 - topic_densities[topic] += row[self.id_field + '_count'] + topic_densities[topic] += row[self.id_field + "_count"] return topic_densities - def vector_search_topics(self, queries: np.array, k: int = 32, depth: int = 3) -> Dict: - ''' + def vector_search_topics( + self, queries: np.array, k: int = 32, depth: int = 3 + ) -> Dict: + """ Given an embedding, returns a normalized distribution over topics. 
Useful for answering the questions such as: @@ -299,11 +341,11 @@ def vector_search_topics(self, queries: np.array, k: int = 32, depth: int = 3) - Returns: A dict mapping `{topic: posterior probability}` for each query. - ''' + """ if queries.ndim != 2: raise ValueError( - 'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).' + "Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d)." ) bytesio = io.BytesIO() @@ -313,10 +355,10 @@ def vector_search_topics(self, queries: np.array, k: int = 32, depth: int = 3) - self.dataset.atlas_api_path + "/v1/project/data/get/embedding/topic", headers=self.dataset.header, json={ - 'atlas_index_id': self.projection.atlas_index_id, - 'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'), - 'k': k, - 'depth': depth, + "atlas_index_id": self.projection.atlas_index_id, + "queries": base64.b64encode(bytesio.getvalue()).decode("utf-8"), + "k": k, + "depth": depth, }, ) if response.status_code != 200: @@ -381,7 +423,7 @@ class AtlasMapEmbeddings: def __init__(self, projection: "AtlasProjection"): self.projection = projection self.id_field = self.projection.dataset.id_field - self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y']) + self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"]) self.dataset = projection.dataset self._latent = None @@ -436,9 +478,13 @@ def latent(self) -> np.array: for path in self.projection._tiles_in_order(): # double with-suffix to remove '.embeddings.feather' - files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather") + files = path.parent.glob( + path.with_suffix("").stem + "-*.embeddings.feather" + ) # Should there be more than 10, we need to sort by int values, not string values - sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1])) + sortable = sorted( + files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]) + ) if len(sortable) == 0: raise FileNotFoundError( "Could not find any embeddings for tile {}".format(path) @@ -446,8 +492,10 @@ def latent(self) -> np.array: ) for file in sortable: tb = feather.read_table(file, memory_map=True) - dims = tb['_embeddings'].type.list_size - all_embeddings.append(pc.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims)) + dims = tb["_embeddings"].type.list_size + all_embeddings.append( + pc.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims) + ) return np.vstack(all_embeddings) def _download_latent(self): @@ -456,28 +504,41 @@ def _download_latent(self): """ logger.warning("Downloading latent embeddings of all datapoints.") limit = 10_000 - route = self.projection.dataset.atlas_api_path + '/v1/project/data/get/embedding/paged' + route = ( + self.projection.dataset.atlas_api_path + + "/v1/project/data/get/embedding/paged" + ) last = None with tqdm(total=self.dataset.total_datums // limit) as pbar: while True: - params = {'projection_id': self.projection.id, "last_file": last, "page_size": limit} - r = requests.post(route, headers=self.projection.dataset.header, json=params) + params = { + "projection_id": self.projection.id, + "last_file": last, + "page_size": limit, + } + r = requests.post( + route, headers=self.projection.dataset.header, json=params + ) if r.status_code == 204: # Download complete! 
break fin = BytesIO(r.content) tb = feather.read_table(fin, memory_map=True) - tilename = tb.schema.metadata[b'tile'].decode("utf-8") - dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather") + tilename = tb.schema.metadata[b"tile"].decode("utf-8") + dest = (self.projection.tile_destination / tilename).with_suffix( + ".embeddings.feather" + ) dest.parent.mkdir(parents=True, exist_ok=True) feather.write_feather(tb, dest) last = tilename pbar.update(1) - def vector_search(self, queries: np.array = None, ids: List[str] = None, k: int = 5) -> Dict[str, List]: - ''' + def vector_search( + self, queries: np.array = None, ids: List[str] = None, k: int = 5 + ) -> Dict[str, List]: + """ Performs semantic vector search over data points on your map. If ids is specified, receive back the most similar data ids in latent vector space to your input ids. If queries is specified, receive back the data ids with representations most similar to the query vectors. @@ -492,29 +553,37 @@ def vector_search(self, queries: np.array = None, ids: List[str] = None, k: int A tuple with two elements containing the following information: neighbors: A set of ids corresponding to the nearest neighbors of each query distances: A set of distances between each query and its neighbors. - ''' + """ if queries is None and ids is None: - raise ValueError('You must specify either a list of datum `ids` or NumPy array of `queries` but not both.') + raise ValueError( + "You must specify either a list of datum `ids` or NumPy array of `queries` but not both." + ) max_k = 128 max_queries = 256 if k > max_k: - raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower") + raise Exception( + f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower" + ) if ids is not None: if len(ids) > max_queries: - raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.") + raise Exception( + f"Max ids per query is {max_queries}. You sent {len(ids)}." + ) if queries is not None: if not isinstance(queries, np.ndarray): raise Exception("`queries` must be an instance of np.array.") if queries.shape[0] > max_queries: - raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.") + raise Exception( + f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}." + ) if queries is not None: if queries.ndim != 2: raise ValueError( - 'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).' + "Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d)." 
) bytesio = io.BytesIO() @@ -522,33 +591,41 @@ def vector_search(self, queries: np.array = None, ids: List[str] = None, k: int if queries is not None: response = requests.post( - self.projection.dataset.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding", + self.projection.dataset.atlas_api_path + + "/v1/project/data/get/nearest_neighbors/by_embedding", headers=self.projection.dataset.header, json={ - 'atlas_index_id': self.projection.atlas_index_id, - 'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'), - 'k': k, + "atlas_index_id": self.projection.atlas_index_id, + "queries": base64.b64encode(bytesio.getvalue()).decode("utf-8"), + "k": k, }, ) else: response = requests.post( - self.projection.dataset.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_id", + self.projection.dataset.atlas_api_path + + "/v1/project/data/get/nearest_neighbors/by_id", headers=self.projection.dataset.header, - json={'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k}, + json={ + "atlas_index_id": self.projection.atlas_index_id, + "datum_ids": ids, + "k": k, + }, ) if response.status_code == 500: - raise Exception('Cannot perform vector search on your map at this time. Try again later.') + raise Exception( + "Cannot perform vector search on your map at this time. Try again later." + ) if response.status_code != 200: raise Exception(response.text) response = response.json() - return response['neighbors'], response['distances'] + return response["neighbors"], response["distances"] def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]: - ''' + """ Deprecated in favor of `map.embeddings.latent`. Iterate through embeddings of your datums. @@ -556,12 +633,14 @@ def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]: Returns: An iterable mapping datum ids to their embeddings. - ''' + """ - raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.") + raise DeprecationWarning( + "Deprecated as of June 2023. Iterate `map.embeddings.latent`." + ) def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool: - ''' + """ Deprecated in favor of `map.embeddings.latent`. Downloads embeddings to the specified save_directory. @@ -572,8 +651,10 @@ def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bo True on success - ''' - raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.") + """ + raise DeprecationWarning( + "Deprecated as of June 2023. Use `map.embeddings.latent`." + ) def __repr__(self) -> str: return str(self.df) @@ -585,7 +666,9 @@ class AtlasMapTags: the associated pandas DataFrame. """ - def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): + def __init__( + self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False + ): self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field @@ -595,9 +678,9 @@ def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = @property def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: - ''' + """ Pandas DataFrame mapping each data point to its tags. 
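A hedged sketch of the vector_search call reformatted above; the 768-wide placeholder query and the `map` handle are assumptions, while the (1, d) shape requirement, the k parameter, and the (neighbors, distances) return follow the code in the hunks.

import numpy as np

query = np.random.rand(1, 768).astype(np.float32)   # a single query must be shaped (1, d)
neighbors, distances = map.embeddings.vector_search(queries=query, k=5)

# Or search by existing datum ids instead of raw vectors (but not both at once):
# neighbors, distances = map.embeddings.vector_search(ids=["datum-0"], k=5)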
- ''' + """ tags = self.get_tags() tag_definition_ids = [tag["tag_definition_id"] for tag in tags] if self.auto_cleanup: @@ -629,34 +712,38 @@ def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: return pa.concat_tables(tbs).to_pandas() def get_tags(self) -> Dict[str, List[str]]: - ''' + """ Retrieves back all tags made in the web browser for a specific map. Each tag is a dictionary containing tag_name, tag_id, and metadata. Returns: A list of tags a user has created for projection. - ''' + """ tags = requests.get( - self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all', + self.dataset.atlas_api_path + "/v1/project/projection/tags/get/all", headers=self.dataset.header, - params={'project_id': self.dataset.id, 'projection_id': self.projection.id, 'include_dsl_rule': False}, + params={ + "project_id": self.dataset.id, + "projection_id": self.projection.id, + "include_dsl_rule": False, + }, ).json() keep_tags = [] for tag in tags: is_complete = requests.get( - self.dataset.atlas_api_path + '/v1/project/projection/tags/status', + self.dataset.atlas_api_path + "/v1/project/projection/tags/status", headers=self.dataset.header, params={ - 'project_id': self.dataset.id, - 'tag_id': tag["tag_id"], + "project_id": self.dataset.id, + "tag_id": tag["tag_id"], }, - ).json()['is_complete'] + ).json()["is_complete"] if is_complete: keep_tags.append(tag) return keep_tags def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): - ''' + """ Returns the datum ids in a given tag. Args: @@ -665,7 +752,7 @@ def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): Returns: List of datum ids. - ''' + """ ordered_tag_paths = self._download_tag(tag_name, overwrite=overwrite) datum_ids = [] for path in ordered_tag_paths: @@ -681,7 +768,11 @@ def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): # filter on rows try: tb = tb.append_column(self.id_field, tile_tb[self.id_field]) - datum_ids.extend(tb.filter(pc.field("bitmask") == True)[self.id_field].to_pylist()) + datum_ids.extend( + tb.filter(pc.field("bitmask") == True)[ + self.id_field + ].to_pylist() + ) except Exception as e: raise Exception(f"Failed to fetch datums in tag. {e}") return datum_ids @@ -716,7 +807,9 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): while download_attempt < 3 and not download_success: download_attempt += 1 if not path.exists() or overwrite: - download_feather(root_url + filename, path, headers=self.dataset.header) + download_feather( + root_url + filename, path, headers=self.dataset.header + ) try: ipc.open_file(path).schema download_success = True @@ -729,13 +822,13 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): return ordered_tag_paths def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): - ''' + """ Attempts to remove outdated tag files based on tag definition ids. Any tag with a definition not in tag_definition_ids will be deleted. Args: tag_definition_ids: A list of tag definition ids to keep. 
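A short sketch of the tag accessors touched above; `map.tags` is an assumed handle for the AtlasMapTags instance, everything else follows the signatures shown in the hunks.

for tag in map.tags.get_tags():                        # dicts with tag_name, tag_id, metadata
    ids = map.tags.get_datums_in_tag(tag["tag_name"])  # datum ids carrying this tag
    print(tag["tag_name"], len(ids))

tag_df = map.tags.df   # datum ids plus one boolean column per tag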
- ''' + """ # NOTE: This currently only gets triggered on `df` property all_quads = list(self.projection._tiles_in_order(coords_only=True)) for quad in tqdm(all_quads): @@ -743,17 +836,21 @@ def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): tile = self.projection.tile_destination / Path(quad_str) tile_dir = tile.parent if tile_dir.exists(): - tagged_files = tile_dir.glob('*_tag*') + tagged_files = tile_dir.glob("*_tag*") for file in tagged_files: tag_definition_id = file.name.split(".")[-2] if tag_definition_id in tag_definition_ids: try: file.unlink() except PermissionError: - print("Permission denied: unable to delete outdated tag file. Skipping") + print( + "Permission denied: unable to delete outdated tag file. Skipping" + ) return except Exception as e: - print(f"Exception occurred when trying to delete outdated tag file: {e}. Skipping") + print( + f"Exception occurred when trying to delete outdated tag file: {e}. Skipping" + ) return def add(self, ids: List[str], tags: List[str]): @@ -809,9 +906,13 @@ def __init__(self, projection: "AtlasProjection", fields=None): def _read_prefetched_tiles_with_sidecars(self, additional_sidecars=None): tbs = [] - root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather")) + root = feather.read_table( + self.projection.tile_destination / Path("0/0/0.feather") + ) try: - small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()]) + small_sidecars = set( + [v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()] + ) except KeyError: small_sidecars = set([]) for path in self.projection._tiles_in_order(): @@ -820,16 +921,22 @@ def _read_prefetched_tiles_with_sidecars(self, additional_sidecars=None): if col[0] == "_": tb = tb.drop([col]) for sidecar_file in small_sidecars: - carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) + carfile = pa.feather.read_table( + path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True + ) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) for big_sidecar in additional_sidecars: fname = ( - base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8") - if big_sidecar != 'datum_id' + base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode( + "utf-8" + ) + if big_sidecar != "datum_id" else big_sidecar ) - carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True) + carfile = pa.feather.read_table( + path.parent / f"{path.stem}.{fname}.feather", memory_map=True + ) for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) @@ -855,11 +962,16 @@ def _download_data(self, fields=None): ] else: for field in sidecars: - assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields." - encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars] - if any(sidecar == 'datum_id' for (field, sidecar) in registered_sidecars): - sidecars.append('datum_id') - encoded_sidecars.append('datum_id') + assert ( + field in self.dataset.dataset_fields + ), f"Field {field} not found in dataset fields." 
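A sketch of constructing AtlasMapData directly with a restricted field list, as get_topic_density does elsewhere in this series; the field names and the `projection` handle are illustrative assumptions, the constructor signature and df property come from the hunks.

from nomic.data_operations import AtlasMapData

data = AtlasMapData(projection, fields=["title", "timestamp"])  # fetch only these sidecar columns
metadata_df = data.df                                           # pandas DataFrame, one row per datapoint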
+ encoded_sidecars = [ + base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") + for sidecar in sidecars + ] + if any(sidecar == "datum_id" for (field, sidecar) in registered_sidecars): + sidecars.append("datum_id") + encoded_sidecars.append("datum_id") for quad in tqdm(all_quads): for encoded_colname in encoded_sidecars: From 002f33e3b812d7e77fbd26384813544279f93ae7 Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Thu, 25 Apr 2024 15:19:01 -0400 Subject: [PATCH 4/5] unborking linting --- nomic/data_operations.py | 471 +++++++++++++++------------------------ 1 file changed, 179 insertions(+), 292 deletions(-) diff --git a/nomic/data_operations.py b/nomic/data_operations.py index 60120337..533f0c7b 100644 --- a/nomic/data_operations.py +++ b/nomic/data_operations.py @@ -12,7 +12,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import numpy as np -import pandas import pandas as pd import pyarrow as pa import requests @@ -21,6 +20,7 @@ from pyarrow import feather, ipc from tqdm import tqdm + from .settings import EMBEDDING_PAGINATION_LIMIT from .utils import download_feather @@ -32,35 +32,25 @@ class AtlasMapDuplicates: your data. """ - def __init__(self, projection: "AtlasProjection"): + def __init__(self, projection: "AtlasProjection"): # type: ignore self.projection = projection self.id_field = self.projection.dataset.id_field try: duplicate_fields = [ - field - for field in projection._fetch_tiles().column_names - if "_duplicate_class" in field + field for field in projection._fetch_tiles().column_names if "_duplicate_class" in field ] - cluster_fields = [ - field - for field in projection._fetch_tiles().column_names - if "_cluster" in field - ] - assert ( - len(duplicate_fields) > 0 - ), "Duplicate detection has not yet been run on this map." + cluster_fields = [field for field in projection._fetch_tiles().column_names if "_cluster" in field] + assert len(duplicate_fields) > 0, "Duplicate detection has not yet been run on this map." self.duplicate_field = duplicate_fields[0] self.cluster_field = cluster_fields[0] self._tb: pa.Table = projection._fetch_tiles().select( [self.id_field, self.duplicate_field, self.cluster_field] ) - except pa.lib.ArrowInvalid as e: + except pa.lib.ArrowInvalid as e: # type: ignore raise ValueError("Duplicate detection has not yet been run on this map.") self.duplicate_field = self.duplicate_field.lstrip("_") self.cluster_field = self.cluster_field.lstrip("_") - self._tb = self._tb.rename_columns( - [self.id_field, self.duplicate_field, self.cluster_field] - ) + self._tb = self._tb.rename_columns([self.id_field, self.duplicate_field, self.cluster_field]) @property def df(self) -> pd.DataFrame: @@ -84,17 +74,13 @@ def deletion_candidates(self) -> List[str]: Returns: The ids for all data points which are semantic duplicates and are candidates for being deleted from the dataset. If you remove these data points from your dataset, your dataset will be semantically deduplicated. 
""" - dupes = self.tb[self.id_field].filter( - pc.equal(self.tb[self.duplicate_field], "deletion candidate") - ) + dupes = self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], 'deletion candidate')) # type: ignore return dupes.to_pylist() def __repr__(self) -> str: repr = f"===Atlas Duplicates for ({self.projection})\n" duplicate_count = len( - self.tb[self.id_field].filter( - pc.equal(self.tb[self.duplicate_field], "deletion candidate") - ) + self.tb[self.id_field].filter(pa.compute.equal(self.tb[self.duplicate_field], 'deletion candidate')) # type: ignore ) cluster_count = len(self.tb[self.cluster_field].value_counts()) repr += f"{duplicate_count} deletion candidates in {cluster_count} clusters\n" @@ -106,7 +92,7 @@ class AtlasMapTopics: Atlas Topics State """ - def __init__(self, projection: "AtlasProjection"): + def __init__(self, projection: "AtlasProjection"): # type: ignore self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field @@ -115,44 +101,34 @@ def __init__(self, projection: "AtlasProjection"): try: self._tb: pa.Table = projection._fetch_tiles() - topic_fields = [ - column - for column in self._tb.column_names - if column.startswith("_topic_depth_") - ] + topic_fields = [column for column in self._tb.column_names if column.startswith("_topic_depth_")] self.depth = len(topic_fields) # If using topic ids, fetch topic labels - if "int" in topic_fields[0]: + if 'int' in topic_fields[0]: new_topic_fields = [] - metadata = self.metadata - label_df = metadata[["topic_id", "depth", "topic_short_description"]] + label_df = self.metadata[["topic_id", "depth", "topic_short_description"]] for d in range(1, self.depth + 1): column = f"_topic_depth_{d}_int" - topic_ids_to_label = self._tb[column].to_pandas().rename("topic_id") - topic_ids_to_label = label_df[label_df["depth"] == d].merge( - topic_ids_to_label, on="topic_id", how="right" + topic_ids_to_label = self._tb[column].to_pandas().rename('topic_id') + topic_ids_to_label = pd.DataFrame(label_df[label_df["depth"] == d]).merge( + topic_ids_to_label, on='topic_id', how='right' ) new_column = f"_topic_depth_{d}" self._tb = self._tb.append_column( - new_column, - pa.Array.from_pandas( - topic_ids_to_label["topic_short_description"] - ), + new_column, pa.Array.from_pandas(topic_ids_to_label["topic_short_description"]) ) new_topic_fields.append(new_column) topic_fields = new_topic_fields - renamed_fields = [f"topic_depth_{i}" for i in range(1, self.depth + 1)] - self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns( - [self.id_field] + renamed_fields - ) + renamed_fields = [f'topic_depth_{i}' for i in range(1, self.depth + 1)] + self._tb = self._tb.select([self.id_field] + topic_fields).rename_columns([self.id_field] + renamed_fields) - except pa.lib.ArrowInvalid as e: + except pa.lib.ArrowInvalid as e: # type: ignore raise ValueError("Topic modeling has not yet been run on this map.") @property - def df(self) -> pandas.DataFrame: + def df(self) -> pd.DataFrame: """ A pandas DataFrame associating each datapoint on your map to their topics as each topic depth. 
""" @@ -168,7 +144,7 @@ def tb(self) -> pa.Table: return self._tb @property - def metadata(self) -> pandas.DataFrame: + def metadata(self) -> pd.DataFrame: """ Pandas DataFrame where each row gives metadata all map topics including: @@ -182,16 +158,14 @@ def metadata(self) -> pandas.DataFrame: response = requests.get( self.projection.dataset.atlas_api_path + "/v1/project/{}/index/projection/{}".format( - self.projection.dataset.meta["id"], self.projection.projection_id + self.projection.dataset.meta['id'], self.projection.projection_id ), headers=self.projection.dataset.header, ) - topics = json.loads(response.text)["topic_models"][0]["features"] - topic_data = [e["properties"] for e in topics] + topics = json.loads(response.text)['topic_models'][0]['features'] + topic_data = [e['properties'] for e in topics] topic_data = pd.DataFrame(topic_data) - column_list = [ - (f"_topic_depth_{i}", f"topic_depth_{i}") for i in range(1, self.depth + 1) - ] + column_list = [(f"_topic_depth_{i}", f"topic_depth_{i}") for i in range(1, self.depth + 1)] column_list.append(("topic", "topic_id")) topic_data = topic_data.rename(columns=dict(column_list)) self._metadata = topic_data @@ -221,13 +195,8 @@ def hierarchy(self) -> Dict: # list of subtopics for the topic at the previous depth for topic_index in range(len(topics) - 1): # depth is index + 1 - if ( - topics[topic_index + 1] - not in topic_hierarchy[(topics[topic_index], topic_index + 1)] - ): - topic_hierarchy[(topics[topic_index], topic_index + 1)].append( - topics[topic_index + 1] - ) + if topics[topic_index + 1] not in topic_hierarchy[(topics[topic_index], topic_index + 1)]: + topic_hierarchy[(topics[topic_index], topic_index + 1)].append(topics[topic_index + 1]) self._hierarchy = dict(topic_hierarchy) return self._hierarchy @@ -251,9 +220,7 @@ def group_by_topic(self, topic_depth: int = 1) -> List[Dict]: datum_id_col = self.dataset.meta["unique_id_field"] df = self.df - topic_datum_dict = ( - df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict() - ) + topic_datum_dict = df.groupby(f"topic_depth_{topic_depth}")[datum_id_col].apply(set).to_dict() topic_df = self.metadata hierarchy = self.hierarchy result = [] @@ -270,20 +237,18 @@ def group_by_topic(self, topic_depth: int = 1) -> List[Dict]: if (topic_label, topic_depth) in hierarchy: subtopics = hierarchy[(topic_label, topic_depth)] result_dict["subtopics"] = subtopics - result_dict["subtopic_ids"] = topic_df[ - topic_df["topic_short_description"].isin(subtopics) - ]["topic_id"].tolist() + result_dict["subtopic_ids"] = topic_df[topic_df["topic_short_description"].isin(subtopics)][ + "topic_id" + ].tolist() result_dict["topic_id"] = topic_metadata["topic_id"].item() result_dict["topic_short_description"] = topic_label - result_dict["topic_long_description"] = topic_metadata[ - "topic_description" - ].item() + result_dict["topic_long_description"] = topic_metadata["topic_description"].item() result_dict["datum_ids"] = datum_ids result.append(result_dict) return result def get_topic_density(self, time_field: str, start: datetime, end: datetime): - """ + ''' Computes the density/frequency of topics in a given interval of a timestamp field. Useful for answering questions such as: @@ -297,12 +262,10 @@ def get_topic_density(self, time_field: str, start: datetime, end: datetime): Returns: A list of `{topic, count}` dictionaries, sorted from largest count to smallest count. 
- """ + ''' data = AtlasMapData(self.projection, fields=[time_field]) time_data = data._tb.select([self.id_field, time_field]) - merged_tb = self._tb.join( - time_data, self.id_field, join_type="inner" - ).combine_chunks() + merged_tb = self._tb.join(time_data, self.id_field, join_type="inner").combine_chunks() del time_data # free up memory @@ -310,23 +273,17 @@ def get_topic_density(self, time_field: str, start: datetime, end: datetime): merged_tb = merged_tb.filter(expr) topic_densities = {} for depth in range(1, self.depth + 1): - topic_column = f"topic_depth_{depth}" - topic_counts = ( - merged_tb.group_by(topic_column) - .aggregate([(self.id_field, "count")]) - .to_pandas() - ) + topic_column = f'topic_depth_{depth}' + topic_counts = merged_tb.group_by(topic_column).aggregate([(self.id_field, "count")]).to_pandas() for _, row in topic_counts.iterrows(): topic = row[topic_column] if topic not in topic_densities: topic_densities[topic] = 0 - topic_densities[topic] += row[self.id_field + "_count"] + topic_densities[topic] += row[self.id_field + '_count'] return topic_densities - def vector_search_topics( - self, queries: np.array, k: int = 32, depth: int = 3 - ) -> Dict: - """ + def vector_search_topics(self, queries: np.ndarray, k: int = 32, depth: int = 3) -> Dict: + ''' Given an embedding, returns a normalized distribution over topics. Useful for answering the questions such as: @@ -341,11 +298,11 @@ def vector_search_topics( Returns: A dict mapping `{topic: posterior probability}` for each query. - """ + ''' if queries.ndim != 2: raise ValueError( - "Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d)." + 'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).' ) bytesio = io.BytesIO() @@ -355,10 +312,10 @@ def vector_search_topics( self.dataset.atlas_api_path + "/v1/project/data/get/embedding/topic", headers=self.dataset.header, json={ - "atlas_index_id": self.projection.atlas_index_id, - "queries": base64.b64encode(bytesio.getvalue()).decode("utf-8"), - "k": k, - "depth": depth, + 'atlas_index_id': self.projection.atlas_index_id, + 'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'), + 'k': k, + 'depth': depth, }, ) if response.status_code != 200: @@ -420,10 +377,10 @@ class AtlasMapEmbeddings: """ - def __init__(self, projection: "AtlasProjection"): + def __init__(self, projection: "AtlasProjection"): # type: ignore self.projection = projection self.id_field = self.projection.dataset.id_field - self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, "x", "y"]) + self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y']) self.dataset = projection.dataset self._latent = None @@ -460,7 +417,7 @@ def projected(self) -> pd.DataFrame: return self.df @property - def latent(self) -> np.array: + def latent(self) -> np.ndarray: """ High dimensional embeddings. 
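A sketch of vector_search_topics as reformatted above, reusing a latent embedding as the query so its dimensionality matches the index; `map.topics` and `map.embeddings` are assumed accessors, the k and depth parameters and the (1, d) shape requirement come from the hunk.

latent = map.embeddings.latent
topic_probs = map.topics.vector_search_topics(queries=latent[:1], k=32, depth=3)
# topic_probs maps each topic to its posterior probability for the query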
@@ -476,26 +433,21 @@ def latent(self) -> np.array: self._download_latent() all_embeddings = [] - for path in self.projection._tiles_in_order(): + for path in self.projection._tiles_in_order(coords_only=False): # double with-suffix to remove '.embeddings.feather' - files = path.parent.glob( - path.with_suffix("").stem + "-*.embeddings.feather" - ) - # Should there be more than 10, we need to sort by int values, not string values - sortable = sorted( - files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]) - ) - if len(sortable) == 0: - raise FileNotFoundError( - "Could not find any embeddings for tile {}".format(path) - + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'." - ) - for file in sortable: - tb = feather.read_table(file, memory_map=True) - dims = tb["_embeddings"].type.list_size - all_embeddings.append( - pc.list_flatten(tb["_embeddings"]).to_numpy().reshape(-1, dims) - ) + if isinstance(path, Path): + files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather") + # Should there be more than 10, we need to sort by int values, not string values + sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1])) + if len(sortable) == 0: + raise FileNotFoundError( + "Could not find any embeddings for tile {}".format(path) + + " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'." + ) + for file in sortable: + tb = feather.read_table(file, memory_map=True) + dims = tb['_embeddings'].type.list_size + all_embeddings.append(pa.compute.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims)) # type: ignore return np.vstack(all_embeddings) def _download_latent(self): @@ -504,41 +456,30 @@ def _download_latent(self): """ logger.warning("Downloading latent embeddings of all datapoints.") limit = 10_000 - route = ( - self.projection.dataset.atlas_api_path - + "/v1/project/data/get/embedding/paged" - ) + route = self.projection.dataset.atlas_api_path + '/v1/project/data/get/embedding/paged' last = None with tqdm(total=self.dataset.total_datums // limit) as pbar: while True: - params = { - "projection_id": self.projection.id, - "last_file": last, - "page_size": limit, - } - r = requests.post( - route, headers=self.projection.dataset.header, json=params - ) + params = {'projection_id': self.projection.id, "last_file": last, "page_size": limit} + r = requests.post(route, headers=self.projection.dataset.header, json=params) if r.status_code == 204: # Download complete! break fin = BytesIO(r.content) tb = feather.read_table(fin, memory_map=True) - tilename = tb.schema.metadata[b"tile"].decode("utf-8") - dest = (self.projection.tile_destination / tilename).with_suffix( - ".embeddings.feather" - ) + tilename = tb.schema.metadata[b'tile'].decode("utf-8") + dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather") dest.parent.mkdir(parents=True, exist_ok=True) feather.write_feather(tb, dest) last = tilename pbar.update(1) def vector_search( - self, queries: np.array = None, ids: List[str] = None, k: int = 5 - ) -> Dict[str, List]: - """ + self, queries: Optional[np.ndarray] = None, ids: Optional[List[str]] = None, k: int = 5 + ) -> Tuple[List, List]: + ''' Performs semantic vector search over data points on your map. If ids is specified, receive back the most similar data ids in latent vector space to your input ids. If queries is specified, receive back the data ids with representations most similar to the query vectors. 
@@ -553,79 +494,60 @@ def vector_search( A tuple with two elements containing the following information: neighbors: A set of ids corresponding to the nearest neighbors of each query distances: A set of distances between each query and its neighbors. - """ + ''' if queries is None and ids is None: - raise ValueError( - "You must specify either a list of datum `ids` or NumPy array of `queries` but not both." - ) + raise ValueError('You must specify either a list of datum `ids` or NumPy array of `queries` but not both.') max_k = 128 max_queries = 256 if k > max_k: - raise Exception( - f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower" - ) + raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower") if ids is not None: if len(ids) > max_queries: - raise Exception( - f"Max ids per query is {max_queries}. You sent {len(ids)}." - ) + raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.") if queries is not None: if not isinstance(queries, np.ndarray): raise Exception("`queries` must be an instance of np.array.") if queries.shape[0] > max_queries: - raise Exception( - f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}." - ) - - if queries is not None: + raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.") if queries.ndim != 2: raise ValueError( - "Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d)." + 'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).' ) bytesio = io.BytesIO() np.save(bytesio, queries) - if queries is not None: response = requests.post( - self.projection.dataset.atlas_api_path - + "/v1/project/data/get/nearest_neighbors/by_embedding", + self.projection.dataset.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding", headers=self.projection.dataset.header, json={ - "atlas_index_id": self.projection.atlas_index_id, - "queries": base64.b64encode(bytesio.getvalue()).decode("utf-8"), - "k": k, + 'atlas_index_id': self.projection.atlas_index_id, + 'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'), + 'k': k, }, ) else: response = requests.post( - self.projection.dataset.atlas_api_path - + "/v1/project/data/get/nearest_neighbors/by_id", + self.projection.dataset.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_id", headers=self.projection.dataset.header, - json={ - "atlas_index_id": self.projection.atlas_index_id, - "datum_ids": ids, - "k": k, - }, + json={'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k}, ) if response.status_code == 500: - raise Exception( - "Cannot perform vector search on your map at this time. Try again later." - ) + raise Exception('Cannot perform vector search on your map at this time. Try again later.') if response.status_code != 200: raise Exception(response.text) response = response.json() - return response["neighbors"], response["distances"] + return response['neighbors'], response['distances'] def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]: - """ + ''' Deprecated in favor of `map.embeddings.latent`. Iterate through embeddings of your datums. @@ -633,14 +555,12 @@ def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]: Returns: An iterable mapping datum ids to their embeddings. - """ + ''' - raise DeprecationWarning( - "Deprecated as of June 2023. Iterate `map.embeddings.latent`." 
- ) + raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.") def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool: - """ + ''' Deprecated in favor of `map.embeddings.latent`. Downloads embeddings to the specified save_directory. @@ -651,10 +571,8 @@ def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bo True on success - """ - raise DeprecationWarning( - "Deprecated as of June 2023. Use `map.embeddings.latent`." - ) + ''' + raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.") def __repr__(self) -> str: return str(self.df) @@ -666,9 +584,7 @@ class AtlasMapTags: the associated pandas DataFrame. """ - def __init__( - self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False - ): + def __init__(self, projection: "AtlasProjection", auto_cleanup: Optional[bool] = False): # type: ignore self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field @@ -678,9 +594,9 @@ def __init__( @property def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: - """ + ''' Pandas DataFrame mapping each data point to its tags. - """ + ''' tags = self.get_tags() tag_definition_ids = [tag["tag_definition_id"] for tag in tags] if self.auto_cleanup: @@ -690,60 +606,55 @@ def df(self, overwrite: Optional[bool] = False) -> pd.DataFrame: tbs = [] all_quads = list(self.projection._tiles_in_order(coords_only=True)) for quad in tqdm(all_quads): - quad_str = os.path.join(*[str(q) for q in quad]) - datum_id_filename = quad_str + "." + "datum_id" + ".feather" - path = self.projection.tile_destination / Path(datum_id_filename) - tb = feather.read_table(path, memory_map=True) - for tag in tags: - tag_definition_id = tag["tag_definition_id"] - tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" - path = self.projection.tile_destination / Path(tag_filename) - tag_tb = feather.read_table(path, memory_map=True) - bitmask = None - if "all_set" in tag_tb.column_names: - if tag_tb["all_set"][0].as_py() == True: - bitmask = pa.array([True] * len(tb), type=pa.bool_()) + if isinstance(quad, Tuple): + quad_str = os.path.join(*[str(q) for q in quad]) + datum_id_filename = quad_str + "." + "datum_id" + ".feather" + path = self.projection.tile_destination / Path(datum_id_filename) + tb = feather.read_table(path, memory_map=True) + for tag in tags: + tag_definition_id = tag["tag_definition_id"] + tag_filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" + path = self.projection.tile_destination / Path(tag_filename) + tag_tb = feather.read_table(path, memory_map=True) + bitmask = None + if "all_set" in tag_tb.column_names: + bool_v = tag_tb["all_set"][0].as_py() == True + bitmask = pa.array([bool_v] * len(tb), type=pa.bool_()) else: - bitmask = pa.array([False] * len(tb), type=pa.bool_()) - else: - bitmask = tag_tb["bitmask"] - tb = tb.append_column(tag["tag_name"], bitmask) - tbs.append(tb) + bitmask = tag_tb["bitmask"] + tb = tb.append_column(tag["tag_name"], bitmask) + tbs.append(tb) return pa.concat_tables(tbs).to_pandas() - def get_tags(self) -> Dict[str, List[str]]: - """ + def get_tags(self) -> List[Dict[str, str]]: + ''' Retrieves back all tags made in the web browser for a specific map. Each tag is a dictionary containing tag_name, tag_id, and metadata. Returns: A list of tags a user has created for projection. 
- """ + ''' tags = requests.get( - self.dataset.atlas_api_path + "/v1/project/projection/tags/get/all", + self.dataset.atlas_api_path + '/v1/project/projection/tags/get/all', headers=self.dataset.header, - params={ - "project_id": self.dataset.id, - "projection_id": self.projection.id, - "include_dsl_rule": False, - }, + params={'project_id': self.dataset.id, 'projection_id': self.projection.id, 'include_dsl_rule': False}, ).json() keep_tags = [] for tag in tags: is_complete = requests.get( - self.dataset.atlas_api_path + "/v1/project/projection/tags/status", + self.dataset.atlas_api_path + '/v1/project/projection/tags/status', headers=self.dataset.header, params={ - "project_id": self.dataset.id, - "tag_id": tag["tag_id"], + 'project_id': self.dataset.id, + 'tag_id': tag["tag_id"], }, - ).json()["is_complete"] + ).json()['is_complete'] if is_complete: keep_tags.append(tag) return keep_tags def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): - """ + ''' Returns the datum ids in a given tag. Args: @@ -752,7 +663,7 @@ def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): Returns: List of datum ids. - """ + ''' ordered_tag_paths = self._download_tag(tag_name, overwrite=overwrite) datum_ids = [] for path in ordered_tag_paths: @@ -768,11 +679,7 @@ def get_datums_in_tag(self, tag_name: str, overwrite: Optional[bool] = False): # filter on rows try: tb = tb.append_column(self.id_field, tile_tb[self.id_field]) - datum_ids.extend( - tb.filter(pc.field("bitmask") == True)[ - self.id_field - ].to_pylist() - ) + datum_ids.extend(tb.filter(pc.field("bitmask") == True)[self.id_field].to_pylist()) except Exception as e: raise Exception(f"Failed to fetch datums in tag. {e}") return datum_ids @@ -799,59 +706,55 @@ def _download_tag(self, tag_name: str, overwrite: Optional[bool] = False): all_quads = list(self.projection._tiles_in_order(coords_only=True)) ordered_tag_paths = [] for quad in tqdm(all_quads): - quad_str = os.path.join(*[str(q) for q in quad]) - filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" - path = self.projection.tile_destination / Path(filename) - download_attempt = 0 - download_success = False - while download_attempt < 3 and not download_success: - download_attempt += 1 - if not path.exists() or overwrite: - download_feather( - root_url + filename, path, headers=self.dataset.header - ) - try: - ipc.open_file(path).schema - download_success = True - except pa.ArrowInvalid: - path.unlink(missing_ok=True) - - if not download_success: - raise Exception(f"Failed to download tag {tag_name}.") - ordered_tag_paths.append(path) + if isinstance(quad, Tuple): + quad_str = os.path.join(*[str(q) for q in quad]) + filename = quad_str + "." + f"_tag.{tag_definition_id}" + ".feather" + path = self.projection.tile_destination / Path(filename) + download_attempt = 0 + download_success = False + while download_attempt < 3 and not download_success: + download_attempt += 1 + if not path.exists() or overwrite: + download_feather(root_url + filename, path, headers=self.dataset.header) + try: + ipc.open_file(path).schema + download_success = True + except pa.ArrowInvalid: + path.unlink(missing_ok=True) + + if not download_success: + raise Exception(f"Failed to download tag {tag_name}.") + ordered_tag_paths.append(path) return ordered_tag_paths def _remove_outdated_tag_files(self, tag_definition_ids: List[str]): - """ + ''' Attempts to remove outdated tag files based on tag definition ids. 
Any tag with a definition not in tag_definition_ids will be deleted. Args: tag_definition_ids: A list of tag definition ids to keep. - """ + ''' # NOTE: This currently only gets triggered on `df` property all_quads = list(self.projection._tiles_in_order(coords_only=True)) for quad in tqdm(all_quads): - quad_str = os.path.join(*[str(q) for q in quad]) - tile = self.projection.tile_destination / Path(quad_str) - tile_dir = tile.parent - if tile_dir.exists(): - tagged_files = tile_dir.glob("*_tag*") - for file in tagged_files: - tag_definition_id = file.name.split(".")[-2] - if tag_definition_id in tag_definition_ids: - try: - file.unlink() - except PermissionError: - print( - "Permission denied: unable to delete outdated tag file. Skipping" - ) - return - except Exception as e: - print( - f"Exception occurred when trying to delete outdated tag file: {e}. Skipping" - ) - return + if isinstance(quad, Tuple): + quad_str = os.path.join(*[str(q) for q in quad]) + tile = self.projection.tile_destination / Path(quad_str) + tile_dir = tile.parent + if tile_dir.exists(): + tagged_files = tile_dir.glob('*_tag*') + for file in tagged_files: + tag_definition_id = file.name.split(".")[-2] + if tag_definition_id in tag_definition_ids: + try: + file.unlink() + except PermissionError: + print("Permission denied: unable to delete outdated tag file. Skipping") + return + except Exception as e: + print(f"Exception occurred when trying to delete outdated tag file: {e}. Skipping") + return def add(self, ids: List[str], tags: List[str]): # ''' @@ -889,54 +792,43 @@ class AtlasMapData: you uploaded with your project. """ - def __init__(self, projection: "AtlasProjection", fields=None): + def __init__(self, projection: "AtlasProjection", fields=None): # type: ignore self.projection = projection self.dataset = projection.dataset self.id_field = self.projection.dataset.id_field - self._tb = None self.fields = fields try: # Run fetch_tiles first to guarantee existence of quad feather files self._basic_data: pa.Table = self.projection._fetch_tiles() sidecars = self._download_data(fields=fields) - self._read_prefetched_tiles_with_sidecars(sidecars) + self._tb = self._read_prefetched_tiles_with_sidecars(sidecars) - except pa.lib.ArrowInvalid as e: + except pa.lib.ArrowInvalid as e: # type: ignore raise ValueError("Failed to fetch tiles for this map") - def _read_prefetched_tiles_with_sidecars(self, additional_sidecars=None): + def _read_prefetched_tiles_with_sidecars(self, additional_sidecars): tbs = [] - root = feather.read_table( - self.projection.tile_destination / Path("0/0/0.feather") - ) + root = feather.read_table(self.projection.tile_destination / Path("0/0/0.feather")) # type: ignore try: - small_sidecars = set( - [v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()] - ) + small_sidecars = set([v for k, v in json.loads(root.schema.metadata[b"sidecars"]).items()]) except KeyError: small_sidecars = set([]) for path in self.projection._tiles_in_order(): - tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"]) + tb = pa.feather.read_table(path).drop(["_id", "ix", "x", "y"]) # type: ignore for col in tb.column_names: if col[0] == "_": tb = tb.drop([col]) for sidecar_file in small_sidecars: - carfile = pa.feather.read_table( - path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True - ) + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{sidecar_file}.feather", memory_map=True) # type: ignore for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) for 
big_sidecar in additional_sidecars: fname = ( - base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode( - "utf-8" - ) - if big_sidecar != "datum_id" + base64.urlsafe_b64encode(big_sidecar.encode("utf-8")).decode("utf-8") + if big_sidecar != 'datum_id' else big_sidecar ) - carfile = pa.feather.read_table( - path.parent / f"{path.stem}.{fname}.feather", memory_map=True - ) + carfile = pa.feather.read_table(path.parent / f"{path.stem}.{fname}.feather", memory_map=True) # type: ignore for col in carfile.column_names: tb = tb.append_column(col, carfile[col]) tbs.append(tb) @@ -962,16 +854,11 @@ def _download_data(self, fields=None): ] else: for field in sidecars: - assert ( - field in self.dataset.dataset_fields - ), f"Field {field} not found in dataset fields." - encoded_sidecars = [ - base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") - for sidecar in sidecars - ] - if any(sidecar == "datum_id" for (field, sidecar) in registered_sidecars): - sidecars.append("datum_id") - encoded_sidecars.append("datum_id") + assert field in self.dataset.dataset_fields, f"Field {field} not found in dataset fields." + encoded_sidecars = [base64.urlsafe_b64encode(sidecar.encode("utf-8")).decode("utf-8") for sidecar in sidecars] + if any(sidecar == 'datum_id' for (field, sidecar) in registered_sidecars): + sidecars.append('datum_id') + encoded_sidecars.append('datum_id') for quad in tqdm(all_quads): for encoded_colname in encoded_sidecars: @@ -986,7 +873,7 @@ def _download_data(self, fields=None): return sidecars @property - def df(self) -> pandas.DataFrame: + def df(self) -> pd.DataFrame: """ A pandas DataFrame associating each datapoint on your map to their metadata. Converting to pandas DataFrame may materialize a large amount of data into memory. From 6b343475f301dcf9a8a79ed0c863a6bc14ad4253 Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Thu, 25 Apr 2024 15:24:12 -0400 Subject: [PATCH 5/5] relint --- nomic/dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nomic/dataset.py b/nomic/dataset.py index 4c3939af..6a4b8ef8 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -653,9 +653,7 @@ def _fetch_tiles(self, overwrite: bool = True): return self._tile_data - def _tiles_in_order( - self, coords_only=False - ) -> Iterator[Union[Tuple[int, int, int], Path]]: + def _tiles_in_order(self, coords_only=False) -> Iterator[Union[Tuple[int, int, int], Path]]: """ Returns: A list of all tiles in the projection in a fixed order so that all