Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable zarr backend testing in data tests [3] #1094

Merged
merged 8 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## Improvements
* Remove dev test from PR [PR #1092](https://github.com/catalystneuro/neuroconv/pull/1092)
* Run only the most basic testing while a PR is on draft [PR #1082](https://github.com/catalystneuro/neuroconv/pull/1082)
* Test that zarr backend_configuration works in gin data tests [PR #1094](https://github.com/catalystneuro/neuroconv/pull/1094)
* Consolidated weekly workflows into one workflow and added email notifications [PR #1088](https://github.com/catalystneuro/neuroconv/pull/1088)
* Avoid running link test when the PR is on draft [PR #1093](https://github.com/catalystneuro/neuroconv/pull/1093)

Expand Down
196 changes: 100 additions & 96 deletions src/neuroconv/tools/testing/data_interface_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,23 @@ def test_source_schema_valid(self):
schema = self.data_interface_cls.get_source_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
def test_conversion_options_schema_valid(self, setup_interface):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def test_metadata_schema_valid(self, setup_interface):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_metadata(self):
def test_metadata(self, setup_interface):
# Validate metadata now happens on the class itself
metadata = self.interface.get_metadata()
self.check_extracted_metadata(metadata)

def check_extracted_metadata(self, metadata: dict):
"""Override this method to make assertions about specific extracted metadata values."""
pass

def test_no_metadata_mutation(self, setup_interface):
"""Ensure the metadata object is not altered by `add_to_nwbfile` method."""

Expand All @@ -107,13 +111,35 @@ def test_no_metadata_mutation(self, setup_interface):
self.interface.add_to_nwbfile(nwbfile=nwbfile, metadata=metadata, **self.conversion_options)
assert metadata == metadata_before_add_method

def check_run_conversion_with_backend_configuration(
self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
):
@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend(self, setup_interface, tmp_path, backend):

nwbfile_path = str(tmp_path / f"conversion_with_backend{backend}-{self.test_name}.nwb")

metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

self.interface.run_conversion(
nwbfile_path=nwbfile_path,
overwrite=True,
metadata=metadata,
backend=backend,
**self.conversion_options,
)

if backend == "zarr":
with NWBZarrIO(path=nwbfile_path, mode="r") as io:
io.read()

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend_configuration(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_path = str(tmp_path / f"conversion_with_backend_configuration{backend}-{self.test_name}.nwb")

nwbfile = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
backend_configuration = self.interface.get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
self.interface.run_conversion(
Expand All @@ -125,6 +151,42 @@ def check_run_conversion_with_backend_configuration(
**self.conversion_options,
)

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_configure_backend_for_equivalent_nwbfiles(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_1 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
nwbfile_2 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)

backend_configuration = get_default_backend_configuration(nwbfile=nwbfile_1, backend=backend)
configure_backend(nwbfile=nwbfile_2, backend_configuration=backend_configuration)

def test_all_conversion_checks(self, setup_interface, tmp_path):
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")
self.nwbfile_path = nwbfile_path

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()

@abstractmethod
def check_read_nwb(self, nwbfile_path: str):
"""Read the produced NWB file and compare it to the interface."""
pass

def run_custom_checks(self):
"""Override this in child classes to inject additional custom checks."""
pass

def check_run_conversion_in_nwbconverter_with_backend(
self, nwbfile_path: str, backend: Literal["hdf5", "zarr"] = "hdf5"
):
Expand Down Expand Up @@ -174,73 +236,6 @@ class TestNWBConverter(NWBConverter):
conversion_options=conversion_options,
)

@abstractmethod
def check_read_nwb(self, nwbfile_path: str):
"""Read the produced NWB file and compare it to the interface."""
pass

def check_extracted_metadata(self, metadata: dict):
"""Override this method to make assertions about specific extracted metadata values."""
pass

def run_custom_checks(self):
"""Override this in child classes to inject additional custom checks."""
pass

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_run_conversion_with_backend(self, setup_interface, tmp_path, backend):

nwbfile_path = str(tmp_path / f"conversion_with_backend{backend}-{self.test_name}.nwb")

metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

self.interface.run_conversion(
nwbfile_path=nwbfile_path,
overwrite=True,
metadata=metadata,
backend=backend,
**self.conversion_options,
)

if backend == "zarr":
with NWBZarrIO(path=nwbfile_path, mode="r") as io:
io.read()

@pytest.mark.parametrize("backend", ["hdf5", "zarr"])
def test_configure_backend_for_equivalent_nwbfiles(self, setup_interface, tmp_path, backend):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
metadata["NWBFile"].update(session_start_time=datetime.now().astimezone())

nwbfile_1 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)
nwbfile_2 = self.interface.create_nwbfile(metadata=metadata, **self.conversion_options)

backend_configuration = get_default_backend_configuration(nwbfile=nwbfile_1, backend=backend)
configure_backend(nwbfile=nwbfile_2, backend_configuration=backend_configuration)

def test_all_conversion_checks(self, setup_interface, tmp_path):
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")
self.nwbfile_path = nwbfile_path

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()


class TemporalAlignmentMixin:
"""
Expand Down Expand Up @@ -718,27 +713,6 @@ def check_shift_segment_timestamps_by_starting_times(self):
):
assert_array_equal(x=retrieved_aligned_timestamps, y=expected_aligned_timestamps)

def test_all_conversion_checks(self, setup_interface, tmp_path):
# The fixture `setup_interface` sets up the necessary objects
interface, test_name = setup_interface

# Create a unique test name and file path
nwbfile_path = str(tmp_path / f"{self.__class__.__name__}_{self.test_name}.nwb")

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()

self.check_run_conversion_in_nwbconverter_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")
self.check_run_conversion_in_nwbconverter_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_run_conversion_with_backend_configuration(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# Any extra custom checks to run
self.run_custom_checks()

def test_interface_alignment(self, setup_interface):

# TODO sorting can have times without associated recordings, test this later
Expand Down Expand Up @@ -872,12 +846,21 @@ class MedPCInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
A mixin for testing MedPC interfaces.
"""

def test_metadata(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_metadata_schema_valid(self):
pass

def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_no_metadata_mutation(self):
pass

Expand All @@ -888,6 +871,10 @@ def check_metadata_schema_valid(self):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def check_metadata(self):
schema = self.interface.get_metadata_schema()
metadata = self.interface.get_metadata()
Expand Down Expand Up @@ -1158,9 +1145,8 @@ def check_read_nwb(self, nwbfile_path: str):
assert one_photon_series.starting_frame is None
assert one_photon_series.timestamps.shape == (15,)

imaging_extractor = self.interface.imaging_extractor
times_from_extractor = imaging_extractor._times
assert_array_equal(one_photon_series.timestamps, times_from_extractor)
interface_times = self.interface.get_original_timestamps()
assert_array_equal(one_photon_series.timestamps, interface_times)


class ScanImageSinglePlaneImagingInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
Expand Down Expand Up @@ -1235,25 +1221,43 @@ def check_read_nwb(self, nwbfile_path: str):
class TDTFiberPhotometryInterfaceMixin(DataInterfaceTestMixin, TemporalAlignmentMixin):
"""Mixin for testing TDT Fiber Photometry interfaces."""

def test_metadata(self):
pass

def test_metadata_schema_valid(self):
pass

def test_no_metadata_mutation(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_no_metadata_mutation(self):
pass

def test_configure_backend_for_equivalent_nwbfiles(self):
pass

def check_metadata(self):
# Validate metadata now happens on the class itself
metadata = self.interface.get_metadata()
self.check_extracted_metadata(metadata)

def check_metadata_schema_valid(self):
schema = self.interface.get_metadata_schema()
Draft7Validator.check_schema(schema=schema)

def check_conversion_options_schema_valid(self):
schema = self.interface.get_conversion_options_schema()
Draft7Validator.check_schema(schema=schema)

def check_no_metadata_mutation(self, metadata: dict):
"""Ensure the metadata object was not altered by `add_to_nwbfile` method."""

Expand Down
22 changes: 15 additions & 7 deletions tests/test_on_data/ecephys/test_recording_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_add_channel_metadata_to_nwb(self, setup_interface):
else:
assert expected_value == extracted_value

# Test addition to electrodes table
# Test addition to electrodes table!~
with NWBHDF5IO(self.nwbfile_path, "r") as io:
nwbfile = io.read()
electrode_table = nwbfile.electrodes.to_dataframe()
Expand All @@ -176,9 +176,6 @@ class TestEDFRecordingInterface(RecordingExtractorInterfaceTestMixin):
interface_kwargs = dict(file_path=str(ECEPHY_DATA_PATH / "edf" / "edf+C.edf"))
save_directory = OUTPUT_PATH

def check_extracted_metadata(self, metadata: dict):
assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 3, 2, 10, 42, 19)

def check_run_conversion_with_backend(self, nwbfile_path: str, backend="hdf5"):
metadata = self.interface.get_metadata()
if "session_start_time" not in metadata["NWBFile"]:
Expand All @@ -198,11 +195,10 @@ def test_all_conversion_checks(self, setup_interface, tmp_path):
self.nwbfile_path = nwbfile_path

# Now run the checks using the setup objects
self.check_conversion_options_schema_valid()
self.check_metadata()
metadata = self.interface.get_metadata()
assert metadata["NWBFile"]["session_start_time"] == datetime(2022, 3, 2, 10, 42, 19)

self.check_run_conversion_with_backend(nwbfile_path=nwbfile_path, backend="hdf5")

self.check_read_nwb(nwbfile_path=nwbfile_path)

# EDF has simultaneous access issues; can't have multiple interfaces open on the same file at once...
Expand All @@ -215,12 +211,24 @@ def test_no_metadata_mutation(self):
def test_run_conversion_with_backend(self):
pass

def test_run_conversion_with_backend_configuration(self):
pass

def test_interface_alignment(self):
pass

def test_configure_backend_for_equivalent_nwbfiles(self):
pass

def test_conversion_options_schema_valid(self):
pass

def test_metadata(self):
pass

def test_conversion_options_schema_valid(self):
pass


class TestIntanRecordingInterfaceRHS(RecordingExtractorInterfaceTestMixin):
data_interface_cls = IntanRecordingInterface
Expand Down
4 changes: 4 additions & 0 deletions tests/test_ophys/test_ophys_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class TestMockImagingInterface(ImagingExtractorInterfaceTestMixin):
data_interface_cls = MockImagingInterface
interface_kwargs = dict()

# TODO: fix this by setting a seed on the dummy imaging extractor
def test_all_conversion_checks(self):
pass


class TestMockSegmentationInterface(SegmentationExtractorInterfaceTestMixin):

Expand Down
Loading