From e49618ef37d1857611b1add69dbf0b4274b0f546 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Mon, 9 Dec 2024 22:13:58 -0600 Subject: [PATCH] add tests and refactor --- .../ecephys/spikeglx/spikeglxnidqinterface.py | 232 +++++++++++++++--- .../ecephys/test_aux_interfaces.py | 100 -------- .../ecephys/test_nidq_interface.py | 52 ++++ 3 files changed, 253 insertions(+), 131 deletions(-) delete mode 100644 tests/test_on_data/ecephys/test_aux_interfaces.py create mode 100644 tests/test_on_data/ecephys/test_nidq_interface.py diff --git a/src/neuroconv/datainterfaces/ecephys/spikeglx/spikeglxnidqinterface.py b/src/neuroconv/datainterfaces/ecephys/spikeglx/spikeglxnidqinterface.py index 1d7079716..4895221e0 100644 --- a/src/neuroconv/datainterfaces/ecephys/spikeglx/spikeglxnidqinterface.py +++ b/src/neuroconv/datainterfaces/ecephys/spikeglx/spikeglxnidqinterface.py @@ -1,39 +1,35 @@ from pathlib import Path -from typing import Optional +from typing import Literal, Optional import numpy as np from pydantic import ConfigDict, DirectoryPath, FilePath, validate_call +from pynwb import NWBFile +from pynwb.base import TimeSeries from .spikeglx_utils import get_session_start_time -from ..baserecordingextractorinterface import BaseRecordingExtractorInterface +from ....basedatainterface import BaseDataInterface from ....tools.signal_processing import get_rising_frames_from_ttl -from ....utils import get_json_schema_from_method_signature +from ....tools.spikeinterface.spikeinterface import _recording_traces_to_hdmf_iterator +from ....utils import ( + calculate_regular_series_rate, + get_json_schema_from_method_signature, +) -class SpikeGLXNIDQInterface(BaseRecordingExtractorInterface): +class SpikeGLXNIDQInterface(BaseDataInterface): """Primary data interface class for converting the high-pass (ap) SpikeGLX format.""" display_name = "NIDQ Recording" - keywords = BaseRecordingExtractorInterface.keywords + ("Neuropixels",) + keywords = ("Neuropixels", "nidq", "NIDQ", 
"SpikeGLX") associated_suffixes = (".nidq", ".meta", ".bin") info = "Interface for NIDQ board recording data." - ExtractorName = "SpikeGLXRecordingExtractor" - stream_id = "nidq" - @classmethod def get_source_schema(cls) -> dict: source_schema = get_json_schema_from_method_signature(method=cls.__init__, exclude=["x_pitch", "y_pitch"]) source_schema["properties"]["file_path"]["description"] = "Path to SpikeGLX .nidq file." return source_schema - def _source_data_to_extractor_kwargs(self, source_data: dict) -> dict: - - extractor_kwargs = source_data.copy() - extractor_kwargs["folder_path"] = self.folder_path - extractor_kwargs["stream_id"] = self.stream_id - return extractor_kwargs - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, @@ -72,18 +68,35 @@ def __init__( if folder_path is not None: self.folder_path = Path(folder_path) + from spikeinterface.extractors import SpikeGLXRecordingExtractor + + self.recording_extractor = SpikeGLXRecordingExtractor( + folder_path=self.folder_path, + stream_id="nidq", + all_annotations=True, + ) + + channel_ids = self.recording_extractor.get_channel_ids() + analog_channel_signatures = ["XA", "MA"] + self.analog_channel_ids = [ch for ch in channel_ids if "XA" in ch or "MA" in ch] + self.has_analog_channels = len(self.analog_channel_ids) > 0 + self.has_digital_channels = len(self.analog_channel_ids) < len(channel_ids) + if self.has_digital_channels: + from spikeinterface.extractors import SpikeGLXEventExtractor + + self.event_extractor = SpikeGLXEventExtractor(folder_path=self.folder_path) + super().__init__( verbose=verbose, load_sync_channel=load_sync_channel, es_key=es_key, + folder_path=self.folder_path, + file_path=file_path, ) - self.source_data.update(file_path=str(file_path)) - self.recording_extractor.set_property( - key="group_name", values=["NIDQChannelGroup"] * self.recording_extractor.get_num_channels() - ) + self.subset_channels = None - signal_info_key = (0, self.stream_id) # Key 
format is (segment_index, stream_id) + signal_info_key = (0, "nidq") # Key format is (segment_index, stream_id) self._signals_info_dict = self.recording_extractor.neo_reader.signals_info_dict[signal_info_key] self.meta = self._signals_info_dict["meta"] @@ -101,24 +114,181 @@ def get_metadata(self) -> dict: manufacturer="National Instruments", ) - # Add groups metadata - metadata["Ecephys"]["Device"] = [device] + metadata["Devices"] = [device] - metadata["Ecephys"]["ElectrodeGroup"][0].update( - name="NIDQChannelGroup", description="A group representing the NIDQ channels.", device=device["name"] - ) - metadata["Ecephys"]["Electrodes"] = [ - dict(name="group_name", description="Name of the ElectrodeGroup this electrode is a part of."), - ] - metadata["Ecephys"]["ElectricalSeriesNIDQ"][ - "description" - ] = "Raw acquisition traces from the NIDQ (.nidq.bin) channels." return metadata def get_channel_names(self) -> list[str]: """Return a list of channel names as set in the recording extractor.""" return list(self.recording_extractor.get_channel_ids()) + def add_to_nwbfile( + self, + nwbfile: NWBFile, + metadata: Optional[dict] = None, + stub_test: bool = False, + starting_time: Optional[float] = None, + write_as: Literal["raw", "lfp", "processed"] = "raw", + write_electrical_series: bool = True, + iterator_type: Optional[str] = "v2", + iterator_opts: Optional[dict] = None, + always_write_timestamps: bool = False, + ): + """ + Add NIDQ board data to an NWB file, including both analog and digital channels if present. + + Parameters + ---------- + nwbfile : NWBFile + The NWB file to which the NIDQ data will be added + metadata : Optional[dict], default: None + Metadata dictionary with device information. If None, uses default metadata + stub_test : bool, default: False + If True, only writes a small amount of data for testing + starting_time : Optional[float], default: None + DEPRECATED: Will be removed in June 2025. 
Starting time offset for the TimeSeries + write_as : Literal["raw", "lfp", "processed"], default: "raw" + DEPRECATED: Will be removed in June 2025. Specifies how to write the data + write_electrical_series : bool, default: True + DEPRECATED: Will be removed in June 2025. Whether to write electrical series data + iterator_type : Optional[str], default: "v2" + Type of iterator to use for data streaming + iterator_opts : Optional[dict], default: None + Additional options for the iterator + always_write_timestamps : bool, default: False + If True, always writes timestamps instead of using sampling rate + """ + import warnings + + if starting_time is not None: + warnings.warn( + "The 'starting_time' parameter is deprecated and will be removed in June 2025.", + DeprecationWarning, + stacklevel=2, + ) + + if write_as != "raw": + warnings.warn( + "The 'write_as' parameter is deprecated and will be removed in June 2025.", + DeprecationWarning, + stacklevel=2, + ) + + if write_electrical_series is not True: + warnings.warn( + "The 'write_electrical_series' parameter is deprecated and will be removed in June 2025.", + DeprecationWarning, + stacklevel=2, + ) + + if stub_test or self.subset_channels is not None: + recording = self.subset_recording(stub_test=stub_test) + else: + recording = self.recording_extractor + + if metadata is None: + metadata = self.get_metadata() + + # Add devices + device_metadata = metadata.get("Devices", []) + for device in device_metadata: + if device["name"] not in nwbfile.devices: + nwbfile.create_device(**device) + + # Add analog and digital channels + if self.has_analog_channels: + self._add_analog_channels( + nwbfile=nwbfile, + recording=recording, + iterator_type=iterator_type, + iterator_opts=iterator_opts, + always_write_timestamps=always_write_timestamps, + ) + + if self.has_digital_channels: + self._add_digital_channels(nwbfile=nwbfile) + + def _add_analog_channels( + self, + nwbfile: NWBFile, + recording, + iterator_type: Optional[str], + 
iterator_opts: Optional[dict], + always_write_timestamps: bool, + ): + """ + Add analog channels from the NIDQ board to the NWB file. + + Parameters + ---------- + nwbfile : NWBFile + The NWB file to add the analog channels to + recording : BaseRecording + The recording extractor containing the analog channels + iterator_type : Optional[str] + Type of iterator to use for data streaming + iterator_opts : Optional[dict] + Additional options for the iterator + always_write_timestamps : bool + If True, always writes timestamps instead of using sampling rate + """ + analog_recorder = recording.select_channels(channel_ids=self.analog_channel_ids) + channel_names = analog_recorder.get_property(key="channel_names") + segment_index = 0 + analog_data_iterator = _recording_traces_to_hdmf_iterator( + recording=analog_recorder, + segment_index=segment_index, + iterator_type=iterator_type, + iterator_opts=iterator_opts, + ) + + name = "TimeSeriesNIDQ" + description = f"Analog data from the NIDQ board. Channels are {channel_names} in that order." 
+ time_series_kwargs = dict(name=name, data=analog_data_iterator, unit="a.u.", description=description) + + if always_write_timestamps: + timestamps = recording.get_times(segment_index=segment_index) + shifted_timestamps = timestamps + time_series_kwargs.update(timestamps=shifted_timestamps) + else: + recording_has_timestamps = recording.has_time_vector(segment_index=segment_index) + if recording_has_timestamps: + timestamps = recording.get_times(segment_index=segment_index) + rate = calculate_regular_series_rate(series=timestamps) + recording_t_start = timestamps[0] + else: + rate = recording.get_sampling_frequency() + recording_t_start = recording._recording_segments[segment_index].t_start or 0 + + if rate: + starting_time = float(recording_t_start) + time_series_kwargs.update(starting_time=starting_time, rate=recording.get_sampling_frequency()) + else: + shifted_timestamps = timestamps + time_series_kwargs.update(timestamps=shifted_timestamps) + + time_series = TimeSeries(**time_series_kwargs) + nwbfile.add_acquisition(time_series) + + def _add_digital_channels(self, nwbfile: NWBFile): + """ + Add digital channels from the NIDQ board to the NWB file as events. + + Parameters + ---------- + nwbfile : NWBFile + The NWB file to add the digital channels to + """ + from ndx_events import Events + + event_channels = self.event_extractor.channel_ids + for channel_id in event_channels: + event_times = self.event_extractor.get_event_times(channel_id=channel_id) + if len(event_times) > 0: + description = f"Events from channel {channel_id}" + events = Events(name=channel_id, timestamps=event_times, description=description) + nwbfile.add_acquisition(events) + def get_event_times_from_ttl(self, channel_name: str) -> np.ndarray: """ Return the start of event times from the rising part of TTL pulses on one of the NIDQ channels. 
diff --git a/tests/test_on_data/ecephys/test_aux_interfaces.py b/tests/test_on_data/ecephys/test_aux_interfaces.py deleted file mode 100644 index 7934e29a1..000000000 --- a/tests/test_on_data/ecephys/test_aux_interfaces.py +++ /dev/null @@ -1,100 +0,0 @@ -import unittest -from datetime import datetime - -import pytest -from parameterized import param, parameterized -from spikeinterface.core.testing import check_recordings_equal -from spikeinterface.extractors import NwbRecordingExtractor - -from neuroconv import NWBConverter -from neuroconv.datainterfaces import SpikeGLXNIDQInterface - -# enable to run locally in interactive mode -try: - from ..setup_paths import ECEPHY_DATA_PATH as DATA_PATH - from ..setup_paths import OUTPUT_PATH -except ImportError: - from setup_paths import ECEPHY_DATA_PATH as DATA_PATH - from setup_paths import OUTPUT_PATH - -if not DATA_PATH.exists(): - pytest.fail(f"No folder found in location: {DATA_PATH}!") - - -def custom_name_func(testcase_func, param_num, param): - interface_name = param.kwargs["data_interface"].__name__ - reduced_interface_name = interface_name.replace("Interface", "") - - return ( - f"{testcase_func.__name__}_{param_num}_" - f"{parameterized.to_safe_name(reduced_interface_name)}" - f"_{param.kwargs.get('case_name', '')}" - ) - - -class TestEcephysAuxNwbConversions(unittest.TestCase): - savedir = OUTPUT_PATH - - parameterized_aux_list = [ - param( - data_interface=SpikeGLXNIDQInterface, - interface_kwargs=dict(file_path=str(DATA_PATH / "spikeglx" / "Noise4Sam_g0" / "Noise4Sam_g0_t0.nidq.bin")), - case_name="load_sync_channel_False", - ), - param( - data_interface=SpikeGLXNIDQInterface, - interface_kwargs=dict( - file_path=str(DATA_PATH / "spikeglx" / "Noise4Sam_g0" / "Noise4Sam_g0_t0.nidq.bin"), - load_sync_channel=True, - ), - case_name="load_sync_channel_True", - ), - ] - - @parameterized.expand(input=parameterized_aux_list, name_func=custom_name_func) - def test_aux_recording_extractor_to_nwb(self, data_interface, 
interface_kwargs, case_name=""): - nwbfile_path = str(self.savedir / f"{data_interface.__name__}_{case_name}.nwb") - - class TestConverter(NWBConverter): - data_interface_classes = dict(TestAuxRecording=data_interface) - - converter = TestConverter(source_data=dict(TestAuxRecording=interface_kwargs)) - - for interface_kwarg in interface_kwargs: - if interface_kwarg in ["file_path", "folder_path"]: - self.assertIn( - member=interface_kwarg, container=converter.data_interface_objects["TestAuxRecording"].source_data - ) - - metadata = converter.get_metadata() - metadata["NWBFile"].update(session_start_time=datetime.now().astimezone()) - converter.run_conversion(nwbfile_path=nwbfile_path, overwrite=True, metadata=metadata) - recording = converter.data_interface_objects["TestAuxRecording"].recording_extractor - - electrical_series_name = metadata["Ecephys"][converter.data_interface_objects["TestAuxRecording"].es_key][ - "name" - ] - - # NWBRecordingExtractor on spikeinterface does not yet support loading data written from multiple segments. 
- if recording.get_num_segments() == 1: - # Spikeinterface behavior is to load the electrode table channel_name property as a channel_id - nwb_recording = NwbRecordingExtractor(file_path=nwbfile_path, electrical_series_name=electrical_series_name) - if "channel_name" in recording.get_property_keys(): - renamed_channel_ids = recording.get_property("channel_name") - else: - renamed_channel_ids = recording.get_channel_ids().astype("str") - recording = recording.channel_slice( - channel_ids=recording.get_channel_ids(), renamed_channel_ids=renamed_channel_ids - ) - - # Edge case that only occurs in testing; I think it's fixed in > 0.96.1 versions (unreleased as of 1/11/23) - # The NwbRecordingExtractor on spikeinterface experiences an issue when duplicated channel_ids - # are specified, which occurs during check_recordings_equal when there is only one channel - if nwb_recording.get_channel_ids()[0] != nwb_recording.get_channel_ids()[-1]: - check_recordings_equal(RX1=recording, RX2=nwb_recording, return_scaled=False) - if recording.has_scaled_traces() and nwb_recording.has_scaled_traces(): - check_recordings_equal(RX1=recording, RX2=nwb_recording, return_scaled=True) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_on_data/ecephys/test_nidq_interface.py b/tests/test_on_data/ecephys/test_nidq_interface.py new file mode 100644 index 000000000..198c2be5a --- /dev/null +++ b/tests/test_on_data/ecephys/test_nidq_interface.py @@ -0,0 +1,52 @@ +import pytest +from pynwb import NWBHDF5IO + +from neuroconv.datainterfaces import SpikeGLXNIDQInterface + +# enable to run locally in interactive mode +try: + from ..setup_paths import ECEPHY_DATA_PATH +except ImportError: + from setup_paths import ECEPHY_DATA_PATH + +if not ECEPHY_DATA_PATH.exists(): + pytest.fail(f"No folder found in location: {ECEPHY_DATA_PATH}!") + + +def test_nidq_interface_digital_data_only(tmp_path): + + nwbfile_path = tmp_path / "nidq_test_digital_only.nwb" + folder_path = 
ECEPHY_DATA_PATH / "spikeglx" / "DigitalChannelTest_g0" + interface = SpikeGLXNIDQInterface(folder_path=folder_path) + interface.run_conversion(nwbfile_path=nwbfile_path, overwrite=True) + + with NWBHDF5IO(nwbfile_path, "r") as io: + nwbfile = io.read() + assert len(nwbfile.acquisition) == 1 # Only one channel has data for this set + events = nwbfile.acquisition["nidq#XD0"] + assert events.name == "nidq#XD0" + assert events.timestamps.size == 326 + + assert len(nwbfile.devices) == 1 + + +def test_nidq_interface_analog_data_only(tmp_path): + + nwbfile_path = tmp_path / "nidq_test_analog_only.nwb" + folder_path = ECEPHY_DATA_PATH / "spikeglx" / "Noise4Sam_g0" + interface = SpikeGLXNIDQInterface(folder_path=folder_path) + interface.run_conversion(nwbfile_path=nwbfile_path, overwrite=True) + + with NWBHDF5IO(nwbfile_path, "r") as io: + nwbfile = io.read() + assert len(nwbfile.acquisition) == 1 # The time series object + time_series = nwbfile.acquisition["TimeSeriesNIDQ"] + assert time_series.name == "TimeSeriesNIDQ" + expected_description = "Analog data from the NIDQ board. Channels are ['XA0' 'XA1' 'XA2' 'XA3' 'XA4' 'XA5' 'XA6' 'XA7'] in that order." + assert time_series.description == expected_description + number_of_samples = time_series.data.shape[0] + assert number_of_samples == 60_864 + number_of_channels = time_series.data.shape[1] + assert number_of_channels == 8 + + assert len(nwbfile.devices) == 1