diff --git a/nomic/dataset.py b/nomic/dataset.py index fce4272..52d6b95 100644 --- a/nomic/dataset.py +++ b/nomic/dataset.py @@ -202,9 +202,12 @@ def _get_organization_by_slug(self, slug: str): self.atlas_api_path + f"/v1/organization/{slug}", headers=self.header, ) - if response.status_code != 200: + if response.status_code in [401, 403, 404]: raise Exception(f"Organization not found: {slug}") - + + if response.status_code != 200: + raise Exception(f"Could not access organization with slug {slug}: {response.text}") + return response.json()["id"] def _get_dataset_by_slug_identifier(self, identifier: str): @@ -216,40 +219,24 @@ def _get_dataset_by_slug_identifier(self, identifier: str): Returns: Returns the requested dataset. """ - - if not self.is_valid_dataset_identifier(identifier=identifier): - raise Exception("Invalid dataset identifier") - - organization_slug = identifier.split("/")[0] - project_slug = identifier.split("/")[1] + try: + organization_slug, project_slug = identifier.split("/") + except ValueError: + raise ValueError(f"Invalid dataset identifier {identifier}") + response = requests.get( self.atlas_api_path + f"/v1/project/{organization_slug}/{project_slug}", headers=self.header, ) - if response.status_code == 403: - raise ValueError(response.json()["detail"]) + if response.status_code in [401, 403, 404]: + raise ValueError(response.json().get("detail", "Unable to find dataset")) if response.status_code != 200: - return None + raise Exception(f"Could not access dataset with identifier {identifier}: {response.text}") return response.json() - def is_valid_dataset_identifier(self, identifier: str): - """ - Checks if a string is a valid identifier for a dataset - - Args: - identifer: the organization slug and dataset slug separated by a slash - - Returns: - Returns the requested dataset. - """ - slugs = identifier.split("/") - if "/" not in identifier or len(slugs) != 2: - return False - return True - def _get_index_job(self, job_id: str): """ @@ -794,32 +781,31 @@ def __init__( self.meta = self._get_project_by_id(dataset_id) return - if not self.is_valid_dataset_identifier(identifier=str(identifier)): + if identifier and not "/" in identifier and re.match(r"^[a-z0-9-]+$", identifier): + # If the identifier is sluggy looking, assume the user is trying to load a project in their default organization. default_org_slug = self._get_current_users_main_organization()["slug"] + logger.warning( + f"Received identifier `{identifier}` that looks like a project name. Assuming you meant `{default_org_slug}/{identifier}`." + " This behavior is deprecated and will be removed in a future version." + ) identifier = default_org_slug + "/" + identifier - dataset = self._get_dataset_by_slug_identifier(identifier=str(identifier)) - - if dataset: # dataset already exists + try: + dataset = self._get_dataset_by_slug_identifier(identifier=str(identifier)) logger.info(f"Loading existing dataset `{identifier}`.") - dataset_id = dataset["id"] - - if dataset_id is None: # if there is no existing project, make a new one. - if unique_id_field is None: # if not all parameters are specified, we weren't trying to make a project - raise ValueError(f"Dataset `{identifier}` does not exist.") - - # if modality is None: - # raise ValueError("You must specify a modality when creating a new dataset.") - # - # assert modality in ['text', 'embedding'], "Modality must be either `text` or `embedding`" - assert identifier is not None - - dataset_id = self._create_project( - identifier=identifier, - description=description, - unique_id_field=unique_id_field, - is_public=is_public, - ) + dataset_id = dataset["id"] + except Exception: + if unique_id_field is None: + # if not all parameters are specified, we weren't trying to make a project + raise + else: + assert identifier is not None + dataset_id = self._create_project( + identifier=identifier, + description=description, + unique_id_field=unique_id_field, + is_public=is_public, + ) self.meta = self._get_project_by_id(project_id=dataset_id)