add tests and refactor
h-mayorquin committed Dec 10, 2024
1 parent a965d23 commit e49618e
Showing 3 changed files with 253 additions and 131 deletions.
232 changes: 201 additions & 31 deletions src/neuroconv/datainterfaces/ecephys/spikeglx/spikeglxnidqinterface.py
@@ -1,39 +1,35 @@
from pathlib import Path
from typing import Optional
from typing import Literal, Optional

import numpy as np
from pydantic import ConfigDict, DirectoryPath, FilePath, validate_call
from pynwb import NWBFile
from pynwb.base import TimeSeries

from .spikeglx_utils import get_session_start_time
from ..baserecordingextractorinterface import BaseRecordingExtractorInterface
from ....basedatainterface import BaseDataInterface
from ....tools.signal_processing import get_rising_frames_from_ttl
from ....utils import get_json_schema_from_method_signature
from ....tools.spikeinterface.spikeinterface import _recording_traces_to_hdmf_iterator
from ....utils import (
calculate_regular_series_rate,
get_json_schema_from_method_signature,
)


class SpikeGLXNIDQInterface(BaseRecordingExtractorInterface):
class SpikeGLXNIDQInterface(BaseDataInterface):
"""Primary data interface class for converting the high-pass (ap) SpikeGLX format."""

display_name = "NIDQ Recording"
keywords = BaseRecordingExtractorInterface.keywords + ("Neuropixels",)
keywords = ("Neuropixels", "nidq", "NIDQ", "SpikeGLX")
associated_suffixes = (".nidq", ".meta", ".bin")
info = "Interface for NIDQ board recording data."

ExtractorName = "SpikeGLXRecordingExtractor"
stream_id = "nidq"

@classmethod
def get_source_schema(cls) -> dict:
source_schema = get_json_schema_from_method_signature(method=cls.__init__, exclude=["x_pitch", "y_pitch"])
source_schema["properties"]["file_path"]["description"] = "Path to SpikeGLX .nidq file."
return source_schema

def _source_data_to_extractor_kwargs(self, source_data: dict) -> dict:

extractor_kwargs = source_data.copy()
extractor_kwargs["folder_path"] = self.folder_path
extractor_kwargs["stream_id"] = self.stream_id
return extractor_kwargs

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
@@ -72,18 +68,35 @@ def __init__(
if folder_path is not None:
self.folder_path = Path(folder_path)

from spikeinterface.extractors import SpikeGLXRecordingExtractor

self.recording_extractor = SpikeGLXRecordingExtractor(
folder_path=self.folder_path,
stream_id="nidq",
all_annotations=True,
)

channel_ids = self.recording_extractor.get_channel_ids()
analog_channel_signatures = ("XA", "MA")
self.analog_channel_ids = [ch for ch in channel_ids if any(signature in ch for signature in analog_channel_signatures)]
self.has_analog_channels = len(self.analog_channel_ids) > 0
self.has_digital_channels = len(self.analog_channel_ids) < len(channel_ids)
if self.has_digital_channels:
from spikeinterface.extractors import SpikeGLXEventExtractor

self.event_extractor = SpikeGLXEventExtractor(folder_path=self.folder_path)

super().__init__(
verbose=verbose,
load_sync_channel=load_sync_channel,
es_key=es_key,
folder_path=self.folder_path,
file_path=file_path,
)
self.source_data.update(file_path=str(file_path))

self.recording_extractor.set_property(
key="group_name", values=["NIDQChannelGroup"] * self.recording_extractor.get_num_channels()
)
self.subset_channels = None

signal_info_key = (0, self.stream_id) # Key format is (segment_index, stream_id)
signal_info_key = (0, "nidq") # Key format is (segment_index, stream_id)
self._signals_info_dict = self.recording_extractor.neo_reader.signals_info_dict[signal_info_key]
self.meta = self._signals_info_dict["meta"]
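
As a rough illustration of the channel-splitting logic in the constructor above, a minimal sketch (the channel ids are hypothetical, assuming spikeinterface-style NIDQ naming in which analog channels contain "XA" or "MA"):

channel_ids = ["XA0", "XA1", "MA0", "XD0"]  # hypothetical ids: three analog channels, one digital
analog_signatures = ("XA", "MA")
analog_channel_ids = [ch for ch in channel_ids if any(signature in ch for signature in analog_signatures)]
has_analog_channels = len(analog_channel_ids) > 0
has_digital_channels = len(analog_channel_ids) < len(channel_ids)
assert analog_channel_ids == ["XA0", "XA1", "MA0"]
assert has_analog_channels and has_digital_channels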

@@ -101,24 +114,181 @@ def get_metadata(self) -> dict:
manufacturer="National Instruments",
)

# Add groups metadata
metadata["Ecephys"]["Device"] = [device]
metadata["Devices"] = [device]

metadata["Ecephys"]["ElectrodeGroup"][0].update(
name="NIDQChannelGroup", description="A group representing the NIDQ channels.", device=device["name"]
)
metadata["Ecephys"]["Electrodes"] = [
dict(name="group_name", description="Name of the ElectrodeGroup this electrode is a part of."),
]
metadata["Ecephys"]["ElectricalSeriesNIDQ"][
"description"
] = "Raw acquisition traces from the NIDQ (.nidq.bin) channels."
return metadata

def get_channel_names(self) -> list[str]:
"""Return a list of channel names as set in the recording extractor."""
return list(self.recording_extractor.get_channel_ids())

def add_to_nwbfile(
self,
nwbfile: NWBFile,
metadata: Optional[dict] = None,
stub_test: bool = False,
starting_time: Optional[float] = None,
write_as: Literal["raw", "lfp", "processed"] = "raw",
write_electrical_series: bool = True,
iterator_type: Optional[str] = "v2",
iterator_opts: Optional[dict] = None,
always_write_timestamps: bool = False,
):
"""
Add NIDQ board data to an NWB file, including both analog and digital channels if present.

Parameters
----------
nwbfile : NWBFile
The NWB file to which the NIDQ data will be added
metadata : Optional[dict], default: None
Metadata dictionary with device information. If None, uses default metadata
stub_test : bool, default: False
If True, only writes a small amount of data for testing
starting_time : Optional[float], default: None
DEPRECATED: Will be removed in June 2025. Starting time offset for the TimeSeries
write_as : Literal["raw", "lfp", "processed"], default: "raw"
DEPRECATED: Will be removed in June 2025. Specifies how to write the data
write_electrical_series : bool, default: True
DEPRECATED: Will be removed in June 2025. Whether to write electrical series data
iterator_type : Optional[str], default: "v2"
Type of iterator to use for data streaming
iterator_opts : Optional[dict], default: None
Additional options for the iterator
always_write_timestamps : bool, default: False
If True, always writes timestamps instead of using sampling rate
"""
import warnings

if starting_time is not None:
warnings.warn(
"The 'starting_time' parameter is deprecated and will be removed in June 2025.",
DeprecationWarning,
stacklevel=2,
)

if write_as != "raw":
warnings.warn(
"The 'write_as' parameter is deprecated and will be removed in June 2025.",
DeprecationWarning,
stacklevel=2,
)

if write_electrical_series is not True:
warnings.warn(
"The 'write_electrical_series' parameter is deprecated and will be removed in June 2025.",
DeprecationWarning,
stacklevel=2,
)

if stub_test or self.subset_channels is not None:
recording = self.subset_recording(stub_test=stub_test)
else:
recording = self.recording_extractor

if metadata is None:
metadata = self.get_metadata()

# Add devices
device_metadata = metadata.get("Devices", [])
for device in device_metadata:
if device["name"] not in nwbfile.devices:
nwbfile.create_device(**device)

# Add analog and digital channels
if self.has_analog_channels:
self._add_analog_channels(
nwbfile=nwbfile,
recording=recording,
iterator_type=iterator_type,
iterator_opts=iterator_opts,
always_write_timestamps=always_write_timestamps,
)

if self.has_digital_channels:
self._add_digital_channels(nwbfile=nwbfile)

def _add_analog_channels(
self,
nwbfile: NWBFile,
recording,
iterator_type: Optional[str],
iterator_opts: Optional[dict],
always_write_timestamps: bool,
):
"""
Add analog channels from the NIDQ board to the NWB file.

Parameters
----------
nwbfile : NWBFile
The NWB file to add the analog channels to
recording : BaseRecording
The recording extractor containing the analog channels
iterator_type : Optional[str]
Type of iterator to use for data streaming
iterator_opts : Optional[dict]
Additional options for the iterator
always_write_timestamps : bool
If True, always writes timestamps instead of using sampling rate
"""
analog_recorder = recording.select_channels(channel_ids=self.analog_channel_ids)
channel_names = analog_recorder.get_property(key="channel_names")
segment_index = 0
analog_data_iterator = _recording_traces_to_hdmf_iterator(
recording=analog_recorder,
segment_index=segment_index,
iterator_type=iterator_type,
iterator_opts=iterator_opts,
)

name = "TimeSeriesNIDQ"
description = f"Analog data from the NIDQ board. Channels are {channel_names} in that order."
time_series_kwargs = dict(name=name, data=analog_data_iterator, unit="a.u.", description=description)

if always_write_timestamps:
timestamps = recording.get_times(segment_index=segment_index)
shifted_timestamps = timestamps
time_series_kwargs.update(timestamps=shifted_timestamps)
else:
recording_has_timestamps = recording.has_time_vector(segment_index=segment_index)
if recording_has_timestamps:
timestamps = recording.get_times(segment_index=segment_index)
rate = calculate_regular_series_rate(series=timestamps)
recording_t_start = timestamps[0]
else:
rate = recording.get_sampling_frequency()
recording_t_start = recording._recording_segments[segment_index].t_start or 0

if rate:
starting_time = float(recording_t_start)
time_series_kwargs.update(starting_time=starting_time, rate=recording.get_sampling_frequency())
else:
shifted_timestamps = timestamps
time_series_kwargs.update(timestamps=shifted_timestamps)

time_series = TimeSeries(**time_series_kwargs)
nwbfile.add_acquisition(time_series)
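
To illustrate the rate-versus-timestamps decision above: calculate_regular_series_rate (imported from neuroconv.utils in this file) is expected to return a sampling rate only when the timestamps are evenly spaced, and the code falls back to explicit timestamps otherwise. A minimal sketch, assuming that return-None-on-irregular behavior:

import numpy as np
from neuroconv.utils import calculate_regular_series_rate

regular_timestamps = np.arange(0.0, 1.0, 0.001)  # evenly spaced, 1 kHz
irregular_timestamps = np.sort(np.random.uniform(0.0, 1.0, size=1000))  # jittered, no single rate

print(calculate_regular_series_rate(series=regular_timestamps))    # ~1000.0 -> starting_time + rate is written
print(calculate_regular_series_rate(series=irregular_timestamps))  # None -> explicit timestamps are written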

def _add_digital_channels(self, nwbfile: NWBFile):
"""
Add digital channels from the NIDQ board to the NWB file as events.

Parameters
----------
nwbfile : NWBFile
The NWB file to add the digital channels to
"""
from ndx_events import Events

event_channels = self.event_extractor.channel_ids
for channel_id in event_channels:
event_times = self.event_extractor.get_event_times(channel_id=channel_id)
if len(event_times) > 0:
description = f"Events from channel {channel_id}"
events = Events(name=channel_id, timestamps=event_times, description=description)
nwbfile.add_acquisition(events)

def get_event_times_from_ttl(self, channel_name: str) -> np.ndarray:
"""
Return the start of event times from the rising part of TTL pulses on one of the NIDQ channels.
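
For orientation, a minimal usage sketch of the refactored interface. The folder path and metadata values below are hypothetical, and the workflow (get_metadata followed by run_conversion) is the standard NeuroConv pattern rather than anything introduced by this commit:

from neuroconv.datainterfaces import SpikeGLXNIDQInterface

# Hypothetical SpikeGLX session folder containing the .nidq.bin and .nidq.meta files.
interface = SpikeGLXNIDQInterface(folder_path="path/to/spikeglx_session")

metadata = interface.get_metadata()
metadata["NWBFile"]["session_description"] = "NIDQ auxiliary acquisition"  # example value

# Analog channels are written as a "TimeSeriesNIDQ" TimeSeries; digital channels, if present,
# are written as ndx-events Events objects in acquisition.
interface.run_conversion(nwbfile_path="nidq.nwb", metadata=metadata)
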
100 changes: 0 additions & 100 deletions tests/test_on_data/ecephys/test_aux_interfaces.py

This file was deleted.
