Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix multiple input types to map data #133

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ lint:
pretty: isort black

test:
source env/bin/activate; pytest -s nomic/tests
source env/bin/activate; pip install pandas; pytest -s nomic/tests

clean:
rm -rf {.pytest_cache,env,nomic.egg-info}
find . | grep -E "(__pycache__|\.pyc|\.pyo$\)" | xargs rm -rf
13 changes: 10 additions & 3 deletions nomic/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@

from .project import AtlasProject
from .settings import *
try:
import pandas as pd
from pandas import DataFrame
except ImportError:
pd = None
DataFrame = None
import pyarrow as pa
from typing import Union
from .utils import b64int, get_random_name


def map_embeddings(
embeddings: np.array,
data: List[Dict] = None,
data: Union[List[Dict], "DataFrame", pa.Table, None] = None,
id_field: str = None,
name: str = None,
description: str = None,
Expand Down Expand Up @@ -146,7 +153,7 @@ def map_embeddings(


def map_text(
data: List[Dict],
data: Union[List[Dict], "DataFrame", pa.Table],
indexed_field: str,
id_field: str = None,
name: str = None,
Expand Down
12 changes: 11 additions & 1 deletion nomic/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,19 @@ def _validate_map_data_inputs(self, colorable_fields, id_field, data):

if id_field in colorable_fields:
raise Exception(f'Cannot color by unique id field: {id_field}')

names: List[str]
if isinstance(data, pa.Table):
names = data.column_names
elif pd is not None and isinstance(data, pd.DataFrame):
names = list(data.columns)
elif isinstance(data, list):
names = list(data[0].keys())
else:
raise ValueError("Invalid data type for data. Must be pyarrow.Table, pandas.DataFrame, or list of dicts.")

for field in colorable_fields:
if field not in data[0]:
if field not in names:
raise Exception(f"Cannot color by field `{field}` as it is not present in the metadata.")

def _get_current_users_main_organization(self):
Expand Down
56 changes: 55 additions & 1 deletion nomic/tests/test_atlas_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import pytest
import requests
from nomic import AtlasProject, atlas

import pyarrow as pa
import pandas as pd

def gen_random_datetime(min_year=1900, max_year=datetime.now().year):
# generate a datetime in format yyyy-mm-dd hh:mm:ss.000000
Expand All @@ -28,6 +29,7 @@ def test_map_idless_embeddings():
AtlasProject(name="test1").delete()



def test_map_embeddings_with_errors():
num_embeddings = 20
embeddings = np.random.rand(num_embeddings, 10)
Expand Down Expand Up @@ -58,6 +60,58 @@ def test_map_embeddings_with_errors():
reset_project_if_exists=True,
)

def test_map_text_pandas():
size = 20
data = pd.DataFrame({
'field': [str(uuid.uuid4()) for i in range(size)],
'id': [str(uuid.uuid4()) for i in range(size)],
'color': [random.choice(['red', 'blue', 'green']) for i in range(size)],
})

project = atlas.map_text(
name='UNITTEST_pandas_text',
id_field='id',
indexed_field="color",
data=data,
is_public=True,
colorable_fields=['color'],
reset_project_if_exists=True,
)

map = project.get_map(name='UNITTEST_pandas_text')

assert project.total_datums == 20

project.delete()

def test_map_embeddings_pandas():
num_embeddings = 20
embeddings = np.random.rand(num_embeddings, 10)
data = pd.DataFrame({
'field': [str(uuid.uuid4()) for i in range(len(embeddings))],
'id': [str(uuid.uuid4()) for i in range(len(embeddings))],
'color': [random.choice(['red', 'blue', 'green']) for i in range(len(embeddings))],
})

project = atlas.map_embeddings(
embeddings=embeddings,
name='UNITTEST_pandas',
id_field='id',
data=data,
is_public=True,
colorable_fields=['color'],
reset_project_if_exists=True,
)

map = project.get_map(name='UNITTEST_pandas')

time.sleep(10)
with tempfile.TemporaryDirectory() as td:
retrieved_embeddings = map.download_embeddings(td)

assert project.total_datums == num_embeddings

project.delete()

def test_map_text_errors():
# no indexed field
Expand Down