Skip to content

Commit

Permalink
test_hdf typing (#578)
Browse files Browse the repository at this point in the history
typing
  • Loading branch information
hussain-jafari authored Feb 3, 2025
1 parent 04d2d6a commit 901a16a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 44 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**3.2.19 - 02/03/25**

- Type-hinting: Fix mypy errors in tests/framework/artifact/test_hdf.py

**3.2.18 - 01/28/25**

- Type-hinting: Fix mypy errors in tests/framework/artifact/test_manager.py
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ exclude = [
'src/vivarium/interface/cli.py',
'src/vivarium/testing_utilities.py',
'tests/examples/test_disease_model.py',
'tests/framework/artifact/test_hdf.py',
'tests/framework/components/mocks.py',
'tests/framework/components/test_component.py',
'tests/framework/components/test_manager.py',
Expand Down Expand Up @@ -83,6 +82,6 @@ module = [
"ipywidgets.*",
"Ipython.*",
"dill",
"tables",
"tables.*"
]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion src/vivarium/framework/artifact/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

import pandas as pd
import tables
from tables.nodes import filenode # type: ignore [import-untyped]
from tables.nodes import filenode

####################
# Public interface #
Expand Down
97 changes: 56 additions & 41 deletions tests/framework/artifact/test_hdf.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import json
import random
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import pytest
import tables
from pytest_mock import MockerFixture
from tables.file import File
from tables.nodes import filenode

from vivarium.framework.artifact import hdf
from vivarium.framework.artifact.hdf import EntityKey
from vivarium.testing_utilities import build_table

_KEYS = [
Expand All @@ -23,12 +27,13 @@


@pytest.fixture
def hdf_keys():
def hdf_keys() -> list[str]:
return _KEYS


@pytest.fixture(params=_KEYS)
def hdf_key(request):
def hdf_key(request: pytest.FixtureRequest) -> str:
assert isinstance(request.param, str)
return request.param


Expand All @@ -41,16 +46,16 @@ def hdf_key(request):
"cause.all_cause.kind_of_new",
]
)
def mock_key(request):
def mock_key(request: pytest.FixtureRequest) -> EntityKey:
return hdf.EntityKey(request.param)


@pytest.fixture(params=[[], {}, ["data"], {"thing": "value"}, "bananas"])
def json_data(request):
def json_data(request: pytest.FixtureRequest) -> Any:
return request.param


def test_touch_no_file(mocker):
def test_touch_no_file(mocker: MockerFixture) -> None:
path = Path("not/an/existing/path.hdf")
tables_mock = mocker.patch("vivarium.framework.artifact.hdf.tables")

Expand All @@ -59,7 +64,7 @@ def test_touch_no_file(mocker):
tables_mock.reset_mock()


def test_touch_exists_but_not_hdf_file_path(hdf_file_path):
def test_touch_exists_but_not_hdf_file_path(hdf_file_path: Path) -> None:
dir_path = Path(hdf_file_path).parent
with pytest.raises(ValueError):
hdf.touch(dir_path)
Expand All @@ -68,7 +73,7 @@ def test_touch_exists_but_not_hdf_file_path(hdf_file_path):
hdf.touch(non_hdf_path)


def test_touch_existing_file(tmpdir):
def test_touch_existing_file(tmpdir: Path) -> None:
path = f"{str(tmpdir)}/test.hdf"

hdf.touch(path)
Expand All @@ -80,7 +85,7 @@ def test_touch_existing_file(tmpdir):
assert hdf.get_keys(path) == []


def test_write_df(hdf_file_path, mock_key, mocker):
def test_write_df(hdf_file_path: Path, mock_key: EntityKey, mocker: MockerFixture) -> None:
df_mock = mocker.patch("vivarium.framework.artifact.hdf._write_pandas_data")
data = pd.DataFrame(np.random.random((10, 3)), columns=["a", "b", "c"], index=range(10))

Expand All @@ -89,13 +94,15 @@ def test_write_df(hdf_file_path, mock_key, mocker):
df_mock.assert_called_once_with(hdf_file_path, mock_key, data)


def test_write_json(hdf_file_path, mock_key, json_data, mocker):
def test_write_json(
hdf_file_path: Path, mock_key: EntityKey, json_data: list[str], mocker: MockerFixture
) -> None:
json_mock = mocker.patch("vivarium.framework.artifact.hdf._write_json_blob")
hdf.write(hdf_file_path, mock_key, json_data)
json_mock.assert_called_once_with(hdf_file_path, mock_key, json_data)


def test_load(hdf_file_path, hdf_key):
def test_load(hdf_file_path: Path, hdf_key: str) -> None:
key = hdf.EntityKey(hdf_key)
data = hdf.load(hdf_file_path, key, filter_terms=None, column_filters=None)
if "restrictions" in key or "versions" in key:
Expand All @@ -106,7 +113,7 @@ def test_load(hdf_file_path, hdf_key):
assert isinstance(data, pd.DataFrame)


def test_load_with_invalid_filters(hdf_file_path, hdf_key):
def test_load_with_invalid_filters(hdf_file_path: Path, hdf_key: str) -> None:
key = hdf.EntityKey(hdf_key)
data = hdf.load(hdf_file_path, key, filter_terms=["fake_filter==0"], column_filters=None)
if "restrictions" in key or "versions" in key:
Expand All @@ -117,7 +124,7 @@ def test_load_with_invalid_filters(hdf_file_path, hdf_key):
assert isinstance(data, pd.DataFrame)


def test_load_with_valid_filters(hdf_file_path, hdf_key):
def test_load_with_valid_filters(hdf_file_path: Path, hdf_key: str) -> None:
key = hdf.EntityKey(hdf_key)
data = hdf.load(hdf_file_path, key, filter_terms=["year == 2006"], column_filters=None)
if "restrictions" in key or "versions" in key:
Expand All @@ -130,7 +137,7 @@ def test_load_with_valid_filters(hdf_file_path, hdf_key):
assert set(data.year) == {2006}


def test_load_filter_empty_data_frame_index(hdf_file_path):
def test_load_filter_empty_data_frame_index(hdf_file_path: Path) -> None:
key = hdf.EntityKey("cause.test.prevalence")
data = pd.DataFrame(data={"age": range(10), "year": range(10), "draw": range(10)})
data = data.set_index(list(data.columns))
Expand All @@ -143,18 +150,20 @@ def test_load_filter_empty_data_frame_index(hdf_file_path):
assert loaded_data.year.unique() == 4


def test_remove(hdf_file_path, hdf_key):
def test_remove(hdf_file_path: Path, hdf_key: str) -> None:
key = hdf.EntityKey(hdf_key)
hdf.remove(hdf_file_path, key)
with tables.open_file(str(hdf_file_path)) as file:
assert key.path not in file


def test_get_keys(hdf_file_path, hdf_keys):
def test_get_keys(hdf_file_path: Path, hdf_keys: list[str]) -> None:
assert sorted(hdf.get_keys(hdf_file_path)) == sorted(hdf_keys)


def test_write_json_blob(hdf_file_path, mock_key, json_data):
def test_write_json_blob(
hdf_file_path: Path, mock_key: EntityKey, json_data: list[str]
) -> None:
hdf._write_json_blob(hdf_file_path, mock_key, json_data)

with tables.open_file(str(hdf_file_path)) as file:
Expand All @@ -164,15 +173,15 @@ def test_write_json_blob(hdf_file_path, mock_key, json_data):
assert data == json_data


def test_write_empty_data_frame(hdf_file_path):
def test_write_empty_data_frame(hdf_file_path: Path) -> None:
key = hdf.EntityKey("cause.test.prevalence")
data = pd.DataFrame(columns=("age", "year", "sex", "draw", "location", "value"))

with pytest.raises(ValueError):
hdf._write_pandas_data(hdf_file_path, key, data)


def test_write_empty_data_frame_index(hdf_file_path):
def test_write_empty_data_frame_index(hdf_file_path: Path) -> None:
key = hdf.EntityKey("cause.test.prevalence")
data = pd.DataFrame(data={"age": range(10), "year": range(10), "draw": range(10)})
data = data.set_index(list(data.columns))
Expand All @@ -185,7 +194,7 @@ def test_write_empty_data_frame_index(hdf_file_path):
assert written_data.equals(data)


def test_write_load_empty_data_frame_index(hdf_file_path):
def test_write_load_empty_data_frame_index(hdf_file_path: Path) -> None:
key = hdf.EntityKey("cause.test.prevalence")
data = pd.DataFrame(data={"age": range(10), "year": range(10), "draw": range(10)})
data = data.set_index(list(data.columns))
Expand All @@ -195,84 +204,90 @@ def test_write_load_empty_data_frame_index(hdf_file_path):
assert loaded_data.equals(data)


def test_write_data_frame(hdf_file_path):
def test_write_data_frame(hdf_file_path: Path) -> None:
key = hdf.EntityKey("cause.test.prevalence")
data = build_table(
lambda x: random.choice([0, 1]),
key_columns={"draw": [0, 1], "location": ["Kenya"]},
)

non_val_columns = data.columns.difference({"value"})
non_val_columns = data.columns.difference(["value"])
data = data.set_index(list(non_val_columns))

hdf._write_pandas_data(hdf_file_path, key, data)

written_data = pd.read_hdf(hdf_file_path, key.path)
assert written_data.equals(data)
assert isinstance(written_data, pd.DataFrame)
pd.testing.assert_frame_equal(written_data, data)

filter_terms = ["draw == 0"]
filter_terms = "draw == 0"
written_data = pd.read_hdf(hdf_file_path, key.path, where=filter_terms)
assert written_data.equals(data.xs(0, level="draw", drop_level=False))

draw_0_data = data.xs(0, level="draw", drop_level=False)
assert isinstance(written_data, pd.DataFrame)
assert isinstance(draw_0_data, pd.DataFrame)
pd.testing.assert_frame_equal(written_data, draw_0_data)


def test_get_keys_private(hdf_file, hdf_keys):
def test_get_keys_private(hdf_file: File, hdf_keys: list[str]) -> None:
assert sorted(hdf._get_keys(hdf_file.root)) == sorted(hdf_keys)


def test_get_node_name(hdf_file, hdf_key):
def test_get_node_name(hdf_file: File, hdf_key: str) -> None:
key = hdf.EntityKey(hdf_key)
assert hdf._get_node_name(hdf_file.get_node(key.path)) == key.measure


def test_get_valid_filter_terms_all_invalid(hdf_key, hdf_file):
def test_get_valid_filter_terms_all_invalid(hdf_key: str, hdf_file: File) -> None:
node = hdf_file.get_node(hdf.EntityKey(hdf_key).path)
if not isinstance(node, tables.earray.EArray):
columns = node.table.colnames
invalid_filter_terms = _construct_no_valid_filters(columns)
assert hdf._get_valid_filter_terms(invalid_filter_terms, columns) is None


def test_get_valid_filter_terms_all_valid(hdf_key, hdf_file):
def test_get_valid_filter_terms_all_valid(hdf_key: str, hdf_file: File) -> None:
node = hdf_file.get_node(hdf.EntityKey(hdf_key).path)
if not isinstance(node, tables.earray.EArray):
columns = node.table.colnames
valid_filter_terms = _construct_all_valid_filters(columns)
assert set(hdf._get_valid_filter_terms(valid_filter_terms, columns)) == set(
valid_filter_terms
)
result = hdf._get_valid_filter_terms(valid_filter_terms, columns)
assert result is not None
assert set(result) == set(valid_filter_terms)


def test_get_valid_filter_terms_some_valid(hdf_key, hdf_file):
def test_get_valid_filter_terms_some_valid(hdf_key: str, hdf_file: File) -> None:
node = hdf_file.get_node(hdf.EntityKey(hdf_key).path)
if not isinstance(node, tables.earray.EArray):
columns = node.table.colnames
invalid_filter_terms = _construct_no_valid_filters(columns)
valid_filter_terms = _construct_all_valid_filters(columns)
all_terms = invalid_filter_terms + valid_filter_terms
result = hdf._get_valid_filter_terms(all_terms, columns)
assert result is not None
assert set(result) == set(valid_filter_terms)


def test_get_valid_filter_terms_no_terms():
def test_get_valid_filter_terms_no_terms() -> None:
assert hdf._get_valid_filter_terms(None, []) is None


def _construct_no_valid_filters(columns):
def _construct_no_valid_filters(columns: list[str]) -> list[str]:
fake_cols = [
c[1:] for c in columns
] # strip out the first char to make a list of all fake cols
terms = [c + " <= 0" for c in fake_cols]
return _complicate_terms_to_parse(terms)


def _construct_all_valid_filters(columns):
def _construct_all_valid_filters(columns: list[str]) -> list[str]:
terms = [
c + "=0" for c in columns
] # assume c is numeric - we won't actually apply filter
return _complicate_terms_to_parse(terms)


def _complicate_terms_to_parse(terms):
def _complicate_terms_to_parse(terms: list[str]) -> list[str]:
n_terms = len(terms)
if n_terms > 1:
# throw in some parens and ifs/ands
Expand All @@ -282,7 +297,7 @@ def _complicate_terms_to_parse(terms):
return ["(" + t + ")" for t in terms]


def test_EntityKey_init_failure():
def test_EntityKey_init_failure() -> None:
bad_keys = ["hello", "a.b.c.d", "", ".", ".coconut", "a.", "a..c"]

for k in bad_keys:
Expand All @@ -291,7 +306,7 @@ def test_EntityKey_init_failure():
hdf.EntityKey(k)


def test_EntityKey_no_name():
def test_EntityKey_no_name() -> None:
type_ = "population"
measure = "structure"
key = hdf.EntityKey(f"{type_}.{measure}")
Expand All @@ -306,7 +321,7 @@ def test_EntityKey_no_name():
assert key.with_measure("age_groups") == hdf.EntityKey("population.age_groups")


def test_EntityKey_with_name():
def test_EntityKey_with_name() -> None:
type_ = "cause"
name = "diarrheal_diseases"
measure = "incidence"
Expand All @@ -322,15 +337,15 @@ def test_EntityKey_with_name():
assert key.with_measure("prevalence") == hdf.EntityKey(f"{type_}.{name}.prevalence")


def test_entity_key_equality():
def test_entity_key_equality() -> None:
type_ = "cause"
name = "diarrheal_diseases"
measure = "incidence"
string = f"{type_}.{name}.{measure}"
key = hdf.EntityKey(string)

class NonString:
def __str__(self):
def __str__(self) -> str:
return string

nonstring = NonString()
Expand Down

0 comments on commit 901a16a

Please sign in to comment.