diff --git a/movement/io/_nwb.py b/movement/io/_nwb.py
new file mode 100644
index 00000000..4122cbb3
--- /dev/null
+++ b/movement/io/_nwb.py
@@ -0,0 +1,174 @@
+"""Functions to convert between movement poses datasets and NWB files.
+
+The pose tracks in NWB files are formatted according to the ``ndx-pose``
+NWB extension, see https://github.com/rly/ndx-pose.
+"""
+
+import logging
+
+import ndx_pose
+import pynwb
+import xarray as xr
+
+from movement.utils.logging import log_warning
+
+logger = logging.getLogger(__name__)
+
+# Default keyword arguments for Skeletons,
+# PoseEstimation and PoseEstimationSeries objects
+SKELETON_KWARGS = dict(edges=None)
+POSE_ESTIMATION_SERIES_KWARGS = dict(
+    reference_frame="(0,0,0) corresponds to ...",
+    confidence_definition=None,
+    conversion=1.0,
+    resolution=-1.0,
+    offset=0.0,
+    starting_time=None,
+    comments="no comments",
+    description="no description",
+    control=None,
+    control_description=None,
+)
+POSE_ESTIMATION_KWARGS = dict(
+    original_videos=None,
+    labeled_videos=None,
+    dimensions=None,
+    devices=None,
+    scorer=None,
+    source_software_version=None,
+)
+
+
+def _merge_kwargs(defaults, overrides):
+    return {**defaults, **(overrides or {})}
+
+
+def _ds_to_pose_and_skeleton_objects(
+    ds: xr.Dataset,
+    pose_estimation_series_kwargs: dict | None = None,
+    pose_estimation_kwargs: dict | None = None,
+    skeleton_kwargs: dict | None = None,
+) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]:
+    """Create PoseEstimation and Skeletons objects from a ``movement`` dataset.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset
+        A single-individual ``movement`` poses dataset.
+    pose_estimation_series_kwargs : dict, optional
+        PoseEstimationSeries keyword arguments.
+        See ``ndx_pose``, by default None
+    pose_estimation_kwargs : dict, optional
+        PoseEstimation keyword arguments. See ``ndx_pose``, by default None
+    skeleton_kwargs : dict, optional
+        Skeleton keyword arguments. See ``ndx_pose``, by default None
+
+    Returns
+    -------
+    pose_estimation : list[ndx_pose.PoseEstimation]
+        List of PoseEstimation objects
+    skeletons : ndx_pose.Skeletons
+        Skeletons object containing all skeletons
+
+    """
+    # Use default kwargs, but updated with any user-provided kwargs
+    pose_estimation_series_kwargs = _merge_kwargs(
+        POSE_ESTIMATION_SERIES_KWARGS, pose_estimation_series_kwargs
+    )
+    pose_estimation_kwargs = _merge_kwargs(
+        POSE_ESTIMATION_KWARGS, pose_estimation_kwargs
+    )
+    skeleton_kwargs = _merge_kwargs(SKELETON_KWARGS, skeleton_kwargs)
+
+    # Extract individual name
+    individual = ds.individuals.values.item()
+
+    # Create a PoseEstimationSeries object for each keypoint
+    pose_estimation_series = []
+    for keypoint in ds.keypoints.to_numpy():
+        pose_estimation_series.append(
+            ndx_pose.PoseEstimationSeries(
+                name=keypoint,
+                data=ds.sel(keypoints=keypoint).position.to_numpy(),
+                confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(),
+                unit="pixels",
+                timestamps=ds.sel(keypoints=keypoint).time.to_numpy(),
+                **pose_estimation_series_kwargs,
+            )
+        )
+    # Create a Skeleton object for the chosen individual
+    skeleton_list = [
+        ndx_pose.Skeleton(
+            name=f"{individual}_skeleton",
+            nodes=ds.keypoints.to_numpy().tolist(),
+            **skeleton_kwargs,
+        )
+    ]
+
+    # Group all PoseEstimationSeries into a PoseEstimation object
+    bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist())
+    description = (
+        f"Estimated positions of {bodyparts_str} for "
+        f"{individual} using {ds.source_software}."
+    )
+    pose_estimation = [
+        ndx_pose.PoseEstimation(
+            name="PoseEstimation",
+            pose_estimation_series=pose_estimation_series,
+            description=description,
+            source_software=ds.source_software,
+            skeleton=skeleton_list[-1],
+            **pose_estimation_kwargs,
+        )
+    ]
+    # Create a Skeletons object
+    skeletons = ndx_pose.Skeletons(skeletons=skeleton_list)
+
+    return pose_estimation, skeletons
+
+
+def _write_behavior_processing_module(
+    nwb_file: pynwb.NWBFile,
+    pose_estimation: ndx_pose.PoseEstimation,
+    skeletons: ndx_pose.Skeletons,
+) -> None:
+    """Write behavior processing data to an NWB file.
+
+    PoseEstimation and Skeletons objects will be written to the NWB file's
+    "behavior" processing module, formatted according to the ``ndx-pose`` NWB
+    extension. If the module does not exist, it will be created.
+    Data will not overwrite any existing objects in the NWB file.
+
+    Parameters
+    ----------
+    nwb_file : pynwb.NWBFile
+        The NWB file object to which the data will be added.
+    pose_estimation : ndx_pose.PoseEstimation
+        PoseEstimation object containing the pose data for an individual.
+    skeletons : ndx_pose.Skeletons
+        Skeletons object containing the skeleton data for an individual.
+
+    """
+    try:
+        behavior_pm = nwb_file.create_processing_module(
+            name="behavior",
+            description="processed behavioral data",
+        )
+        logger.debug("Created behavior processing module in NWB file.")
+    except ValueError:
+        logger.debug(
+            "Data will be added to existing behavior processing module."
+        )
+        behavior_pm = nwb_file.processing["behavior"]
+
+    try:
+        behavior_pm.add(skeletons)
+        logger.info("Added Skeletons object to NWB file.")
+    except ValueError:
+        log_warning("Skeletons object already exists. Skipping...")
+
+    try:
+        behavior_pm.add(pose_estimation)
+        logger.info("Added PoseEstimation object to NWB file.")
+    except ValueError:
+        log_warning("PoseEstimation object already exists. Skipping...")
diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index 9c6c8216..41bfa807 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -7,6 +7,7 @@
 import h5py
 import numpy as np
 import pandas as pd
+import pynwb
 import xarray as xr
 from sleap_io.io.slp import read_labels
 from sleap_io.model.labels import Labels
@@ -97,7 +98,11 @@ def from_numpy(
 def from_file(
     file_path: Path | str,
     source_software: Literal[
-        "DeepLabCut", "SLEAP", "LightningPose", "Anipose"
+        "DeepLabCut",
+        "SLEAP",
+        "LightningPose",
+        "Anipose",
+        "NWB",
     ],
     fps: float | None = None,
     **kwargs,
@@ -112,11 +117,14 @@ def from_file(
         ``from_slp_file()`` or ``from_lp_file()`` functions. One of these
         functions will be called internally, based on the value of
         ``source_software``.
-    source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose"
+    source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \
+        "NWB"}
         The source software of the file.
     fps : float, optional
         The number of frames per second in the video. If None (default),
         the ``time`` coordinates will be in frame numbers.
+        The fps argument is ignored when source_software is "NWB", as the
+        frame rate will be estimated from timestamps in the file.
     **kwargs : dict, optional
         Additional keyword arguments to pass to the software-specific
         loading functions that are listed under "See Also".
@@ -133,6 +141,7 @@
     movement.io.load_poses.from_sleap_file
     movement.io.load_poses.from_lp_file
     movement.io.load_poses.from_anipose_file
+    movement.io.load_poses.from_nwb_file
 
     Examples
    --------
     >>> ds = load_poses.from_file(
) """ + if source_software == "NWB" and fps is not None: + log_warning( + "The fps argument is ignored when loading from an NWB file. " + "The frame rate will be estimated from timestamps in the file." + ) + if source_software == "DeepLabCut": return from_dlc_file(file_path, fps) elif source_software == "SLEAP": @@ -150,6 +165,8 @@ def from_file( return from_lp_file(file_path, fps) elif source_software == "Anipose": return from_anipose_file(file_path, fps, **kwargs) + elif source_software == "NWB": + return from_nwb_file(file_path, **kwargs) else: raise log_error( ValueError, f"Unsupported source software: {source_software}" @@ -825,3 +842,127 @@ def from_anipose_file( return from_anipose_style_df( anipose_df, fps=fps, individual_name=individual_name ) + + +def from_nwb_file( + file: str | Path | pynwb.NWBFile, + key_name: str = "PoseEstimation", +) -> xr.Dataset: + """Create a ``movement`` poses dataset from an NWB file. + + The input can be a path to an NWB file on disk or an open NWB file object. + The data will be extracted from the NWB file's behavior processing module, + which is assumed to contain a ``PoseEstimation`` object formatted according + to the ``ndx-pose`` NWB extension [1]_. + + Parameters + ---------- + file : str | Path | NWBFile + Path to the NWB file on disk (ending in ".nwb"), + or an open NWBFile object. + key_name: str, optional + Name of the ``PoseEstimation`` object in the NWB "behavior" + processing module, by default "PoseEstimation". + + Returns + ------- + movement_ds : xr.Dataset + A single-individual ``movement`` dataset containing the pose tracks, + confidence scores, and associated metadata. + + References + ---------- + .. [1] https://github.com/rly/ndx-pose + + Examples + -------- + Open an NWB file and load pose tracks from the file object: + + >>> import pynwb + >>> with pynwb.NWBHDF5IO("path/to/file.nwb", mode="r") as io: + ... nwb_file = io.read() + ... ds = load_poses.from_nwb_file(nwb_file) + + Directly load pose tracks from an NWB file on disk: + + >>> from movement.io import load_poses + >>> ds = load_poses.from_nwb_file("path/to/file.nwb") + + Load two single-individual datasets from two NWB files and merge them + into a multi-individual dataset: + + >>> ds_singles = [ + ... load_poses.from_nwb_file(f) for f in ["id1.nwb", "id2.nwb"] + ... ] + >>> ds_multi = xr.merge(datasets) + + """ + if isinstance(file, str | Path): + valid_file = ValidFile( + file, expected_permission="r", expected_suffix=[".nwb"] + ) + with pynwb.NWBHDF5IO(valid_file.path, mode="r") as io: + nwb_file = io.read() + ds = _ds_from_nwb_object(nwb_file, key_name=key_name) + ds.attrs["source_file"] = valid_file.path + elif isinstance(file, pynwb.NWBFile): + ds = _ds_from_nwb_object(file, key_name=key_name) + ds.attrs["source_file"] = None + else: + raise log_error( + TypeError, + "Expected file to be one of following types: str, Path, " + f"pynwb.NWBFile. Got {type(file)} instead.", + ) + return ds + + +def _ds_from_nwb_object( + nwb_file: pynwb.NWBFile, + key_name: str = "PoseEstimation", +) -> xr.Dataset: + """Extract a ``movement`` poses dataset from an open NWB file object. + + Parameters + ---------- + nwb_file : pynwb.NWBFile + An open NWB file object. + key_name: str + Name of the ``PoseEstimation`` object in the NWB "behavior" + processing module, by default "PoseEstimation". 
+
+    """
+    if isinstance(file, str | Path):
+        valid_file = ValidFile(
+            file, expected_permission="r", expected_suffix=[".nwb"]
+        )
+        with pynwb.NWBHDF5IO(valid_file.path, mode="r") as io:
+            nwb_file = io.read()
+            ds = _ds_from_nwb_object(nwb_file, key_name=key_name)
+            ds.attrs["source_file"] = valid_file.path
+    elif isinstance(file, pynwb.NWBFile):
+        ds = _ds_from_nwb_object(file, key_name=key_name)
+        ds.attrs["source_file"] = None
+    else:
+        raise log_error(
+            TypeError,
+            "Expected file to be one of the following types: str, Path, "
+            f"pynwb.NWBFile. Got {type(file)} instead.",
+        )
+    return ds
+
+
+def _ds_from_nwb_object(
+    nwb_file: pynwb.NWBFile,
+    key_name: str = "PoseEstimation",
+) -> xr.Dataset:
+    """Extract a ``movement`` poses dataset from an open NWB file object.
+
+    Parameters
+    ----------
+    nwb_file : pynwb.NWBFile
+        An open NWB file object.
+    key_name : str
+        Name of the ``PoseEstimation`` object in the NWB "behavior"
+        processing module, by default "PoseEstimation".
+
+    Returns
+    -------
+    movement_ds : xr.Dataset
+        A single-individual ``movement`` poses dataset.
+
+    """
+    pose_estimation = nwb_file.processing["behavior"][key_name]
+    source_software = pose_estimation.fields["source_software"]
+    pose_estimation_series = pose_estimation.fields["pose_estimation_series"]
+    single_keypoint_datasets = []
+    for keypoint, pse in pose_estimation_series.items():
+        # Extract position and confidence data for each keypoint
+        position_data = np.asarray(pse.data)  # shape: (n_time, n_space)
+        confidence_data = (  # shape: (n_time,)
+            np.asarray(pse.confidence)
+            if getattr(pse, "confidence", None) is not None
+            else np.full(position_data.shape[0], np.nan)
+        )
+
+        # Estimate fps as the median of the inverse time differences
+        # between consecutive timestamps
+        fps = np.nanmedian(1 / np.diff(pse.timestamps))
+
+        single_keypoint_datasets.append(
+            # create movement dataset with 1 keypoint and 1 individual
+            from_numpy(
+                position_data[:, :, np.newaxis, np.newaxis],
+                confidence_data[:, np.newaxis, np.newaxis],
+                individual_names=[nwb_file.identifier],
+                keypoint_names=[keypoint],
+                fps=round(fps, 6),
+                source_software=source_software,
+            )
+        )
+
+    return xr.merge(single_keypoint_datasets)
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
index 6a585457..e2eb8a2f 100644
--- a/movement/io/save_poses.py
+++ b/movement/io/save_poses.py
@@ -7,8 +7,13 @@
 import h5py
 import numpy as np
 import pandas as pd
+import pynwb
 import xarray as xr
 
+from movement.io._nwb import (
+    _ds_to_pose_and_skeleton_objects,
+    _write_behavior_processing_module,
+)
 from movement.utils.logging import log_error
 from movement.validators.datasets import ValidPosesDataset
 from movement.validators.files import ValidFile
@@ -358,6 +363,103 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
     logger.info(f"Saved poses dataset to {file.path}.")
+
+
+def to_nwb_file(
+    ds: xr.Dataset,
+    nwb_files: pynwb.NWBFile | list[pynwb.NWBFile],
+    *,
+    pose_estimation_series_kwargs: dict | None = None,
+    pose_estimation_kwargs: dict | None = None,
+    skeletons_kwargs: dict | None = None,
+) -> None:
+    """Save a ``movement`` dataset to one or more open NWB file objects.
+
+    The data will be written to the NWB file(s) in the "behavior" processing
+    module, formatted according to the ``ndx-pose`` NWB extension [1]_.
+    Each individual in the dataset will be written to a separate file object,
+    as required by the NWB format. Note that the NWBFile object(s) are not
+    automatically saved to disk.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        ``movement`` poses dataset containing the data to be
+        converted to NWB.
+    nwb_files : pynwb.NWBFile | list[pynwb.NWBFile]
+        An NWBFile object or a list of such objects to which the data
+        will be added.
+    pose_estimation_series_kwargs : dict, optional
+        PoseEstimationSeries keyword arguments.
+        See ``ndx-pose``, by default None
+    pose_estimation_kwargs : dict, optional
+        PoseEstimation keyword arguments. See ``ndx-pose``, by default None
+    skeletons_kwargs : dict, optional
+        Skeleton keyword arguments. See ``ndx-pose``, by default None
+
+    Raises
+    ------
+    ValueError
+        If the number of NWBFiles is not equal to the number of individuals.
+
+    Notes
+    -----
+    The data will not overwrite any existing PoseEstimation or Skeletons
+    objects in the NWB file(s). If the objects already exist, the function
+    will skip adding them.
+
+    References
+    ----------
+    .. [1] https://github.com/rly/ndx-pose
+
+    Examples
+    --------
+    Let's load a sample dataset containing pose tracks from two mice:
+
+    >>> from movement import sample_data
+    >>> ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
+
+    We will create two NWBFiles, one for each individual, save the pose
+    data to them, and then write the files to disk:
+
+    >>> import datetime as dt
+    >>> from pynwb import NWBFile, NWBHDF5IO
+    >>> from movement.io import save_poses
+    >>> nwb_files = [
+    ...     NWBFile(
+    ...         session_description="session_description",
+    ...         identifier=ind,
+    ...         session_start_time=dt.datetime.now(dt.timezone.utc),
+    ...     )
+    ...     for ind in ds.individuals.values
+    ... ]
+    >>> save_poses.to_nwb_file(ds, nwb_files)
+    >>> for file in nwb_files:
+    ...     with NWBHDF5IO(f"{file.identifier}.nwb", "w") as io:
+    ...         io.write(file)
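+
+    The ``ndx-pose`` keyword arguments can be customized via the optional
+    dictionaries; for example, to override the default placeholder
+    ``reference_frame`` of each PoseEstimationSeries (the value below is
+    illustrative only):
+
+    >>> save_poses.to_nwb_file(
+    ...     ds,
+    ...     nwb_files,
+    ...     pose_estimation_series_kwargs={
+    ...         "reference_frame": "(0,0) is the top-left corner"
+    ...     },
+    ... )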
+
+    """
+    if isinstance(nwb_files, pynwb.NWBFile):
+        nwb_files = [nwb_files]
+
+    if len(nwb_files) != len(ds.individuals):
+        raise log_error(
+            ValueError,
+            "Number of NWBFile objects must be equal to the number of "
+            "individuals, as NWB requires one file per individual (subject).",
+        )
+
+    for nwb_file, individual in zip(
+        nwb_files, ds.individuals.values, strict=False
+    ):
+        pose_estimation, skeletons = _ds_to_pose_and_skeleton_objects(
+            ds.sel(individuals=individual),
+            pose_estimation_series_kwargs,
+            pose_estimation_kwargs,
+            skeletons_kwargs,
+        )
+        _write_behavior_processing_module(
+            nwb_file, pose_estimation[0], skeletons
+        )
+
+
 def _remove_unoccupied_tracks(ds: xr.Dataset):
     """Remove tracks that are completely unoccupied from the dataset.
diff --git a/pyproject.toml b/pyproject.toml
index 27348c29..fb64d957 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,8 @@ dependencies = [
     "sleap-io",
    "xarray[accel,viz]",
    "PyYAML",
+    "pynwb",
+    "ndx-pose>=0.2.1",
 ]
 
 classifiers = [
diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py
new file mode 100644
index 00000000..5d5b2358
--- /dev/null
+++ b/tests/test_unit/test_nwb.py
@@ -0,0 +1,210 @@
+import datetime
+
+import ndx_pose
+import numpy as np
+import xarray as xr
+from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons
+from pynwb import NWBHDF5IO, NWBFile
+from pynwb.file import Subject
+
+from movement import sample_data
+from movement.io._nwb import _ds_to_pose_and_skeleton_objects
+from movement.io.load_poses import from_nwb_file
+from movement.io.save_poses import to_nwb_file
+
+
+def test_ds_to_pose_and_skeleton_objects():
+    # Create a sample dataset
+    ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
+
+    # Call the function
+    pose_estimation, skeletons = _ds_to_pose_and_skeleton_objects(
+        ds.sel(individuals="individual1"),
+        pose_estimation_series_kwargs=None,
+        pose_estimation_kwargs=None,
+        skeleton_kwargs=None,
+    )
+
+    # Assert the output types
+    assert isinstance(pose_estimation, list)
+    assert isinstance(skeletons, ndx_pose.Skeletons)
+
+    # Assert the length of pose_estimation list (n_individuals)
+    assert len(pose_estimation) == 1
+
+    # Assert the length of pose_estimation_series list (n_keypoints)
+    assert len(pose_estimation[0].pose_estimation_series) == 12
+
+    # Assert the name of the first PoseEstimationSeries (first keypoint)
+    assert "snout" in pose_estimation[0].pose_estimation_series
+
+    # Assert the name of the Skeleton (individual1_skeleton)
+    assert "individual1_skeleton" in skeletons.skeletons
+
+
+def create_test_pose_estimation_series(
+    n_time=100, n_dims=2, keypoint="front_left_paw"
+):
+    rng = np.random.default_rng(42)
+    data = rng.random((n_time, n_dims))  # num_frames x n_space_dims (2 or 3)
+    # Create an array of timestamps in seconds, assuming fps=10.0
+    timestamps = np.arange(n_time) / 10.0
+    confidence = np.ones((n_time,))  # a confidence value for every frame
+    reference_frame = "(0,0,0) corresponds to ..."
+    confidence_definition = "Softmax output of the deep neural network."
+
+    return PoseEstimationSeries(
+        name=keypoint,
+        description="Marker placed around fingers of front left paw.",
+        data=data,
+        unit="pixels",
+        reference_frame=reference_frame,
+        timestamps=timestamps,
+        confidence=confidence,
+        confidence_definition=confidence_definition,
+    )
+
+
+def test_save_poses_to_single_nwb_file():
+    ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
+    session_start_time = datetime.datetime.now(datetime.timezone.utc)
+    nwbfile_individual1 = NWBFile(
+        session_description="session_description",
+        identifier="individual1",
+        session_start_time=session_start_time,
+    )
+    to_nwb_file(ds.sel(individuals=["individual1"]), nwbfile_individual1)
+    assert (
+        "PoseEstimation"
+        in nwbfile_individual1.processing["behavior"].data_interfaces
+    )
+    assert (
+        "Skeletons"
+        in nwbfile_individual1.processing["behavior"].data_interfaces
+    )
+
+
+def test_save_poses_to_multiple_nwb_files():
+    ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
+    session_start_time = datetime.datetime.now(datetime.timezone.utc)
+    nwbfile_individual1 = NWBFile(
+        session_description="session_description",
+        identifier="individual1",
+        session_start_time=session_start_time,
+    )
+    nwbfile_individual2 = NWBFile(
+        session_description="session_description",
+        identifier="individual2",
+        session_start_time=session_start_time,
+    )
+
+    nwbfiles = [nwbfile_individual1, nwbfile_individual2]
+    to_nwb_file(ds, nwbfiles)
+    # Check that each file received a PoseEstimation object
+    for file in nwbfiles:
+        assert (
+            "PoseEstimation" in file.processing["behavior"].data_interfaces
+        )
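+
+
+# A minimal additional sketch: per the documented behavior of
+# _write_behavior_processing_module, saving to the same NWBFile twice
+# should skip the existing objects with a warning rather than raise.
+def test_save_poses_twice_skips_existing_objects():
+    ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
+    nwb_file = NWBFile(
+        session_description="session_description",
+        identifier="individual1",
+        session_start_time=datetime.datetime.now(datetime.timezone.utc),
+    )
+    to_nwb_file(ds.sel(individuals=["individual1"]), nwb_file)
+    # The second call hits the existing objects and should be a no-op
+    to_nwb_file(ds.sel(individuals=["individual1"]), nwb_file)
+    assert (
+        "PoseEstimation" in nwb_file.processing["behavior"].data_interfaces
+    )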
+
+
+def create_test_pose_nwb(identifier="subject1") -> NWBFile:
+    # initialize an NWBFile object
+    nwb_file = NWBFile(
+        session_description="session_description",
+        identifier=identifier,
+        session_start_time=datetime.datetime.now(datetime.timezone.utc),
+    )
+
+    # add a subject to the NWB file
+    subject = Subject(subject_id=identifier, species="Mus musculus")
+    nwb_file.subject = subject
+
+    # create a skeleton object, named after the individual
+    skeleton = Skeleton(
+        name=f"{identifier}_skeleton",
+        nodes=["front_left_paw", "body", "front_right_paw"],
+        edges=np.array([[0, 1], [1, 2]], dtype="uint8"),
+        subject=subject,
+    )
+    skeletons = Skeletons(skeletons=[skeleton])
+
+    # create a device for the camera
+    camera1 = nwb_file.create_device(
+        name="camera1",
+        description="camera for recording behavior",
+        manufacturer="my manufacturer",
+    )
+
+    n_time = 100
+    n_dims = 2  # 2D data
+    front_left_paw = create_test_pose_estimation_series(
+        n_time=n_time, n_dims=n_dims, keypoint="front_left_paw"
+    )
+
+    body = create_test_pose_estimation_series(
+        n_time=n_time, n_dims=n_dims, keypoint="body"
+    )
+    front_right_paw = create_test_pose_estimation_series(
+        n_time=n_time, n_dims=n_dims, keypoint="front_right_paw"
+    )
+
+    # store all PoseEstimationSeries in a list
+    pose_estimation_series = [front_left_paw, body, front_right_paw]
+
+    pose_estimation = PoseEstimation(
+        name="PoseEstimation",
+        pose_estimation_series=pose_estimation_series,
+        description=(
+            "Estimated positions of front paws of subject1 using DeepLabCut."
+        ),
+        original_videos=["path/to/camera1.mp4"],
+        labeled_videos=["path/to/camera1_labeled.mp4"],
+        dimensions=np.array(
+            [[640, 480]], dtype="uint16"
+        ),  # pixel dimensions of the video
+        devices=[camera1],
+        scorer="DLC_resnet50_openfieldOct30shuffle1_1600",
+        source_software="DeepLabCut",
+        source_software_version="2.3.8",
+        skeleton=skeleton,  # link to the skeleton object
+    )
+
+    behavior_pm = nwb_file.create_processing_module(
+        name="behavior",
+        description="processed behavioral data",
+    )
+    behavior_pm.add(skeletons)
+    behavior_pm.add(pose_estimation)
+
+    return nwb_file
+ ), + original_videos=["path/to/camera1.mp4"], + labeled_videos=["path/to/camera1_labeled.mp4"], + dimensions=np.array( + [[640, 480]], dtype="uint16" + ), # pixel dimensions of the video + devices=[camera1], + scorer="DLC_resnet50_openfieldOct30shuffle1_1600", + source_software="DeepLabCut", + source_software_version="2.3.8", + skeleton=skeleton, # link to the skeleton object + ) + + behavior_pm = nwb_file.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + behavior_pm.add(skeletons) + behavior_pm.add(pose_estimation) + + return nwb_file + + +def test_load_poses_from_nwb_file(tmp_path): + nwb_file = create_test_pose_nwb() + + # write the NWBFile to disk (temporary file) + file_path = tmp_path / "test_pose.nwb" + with NWBHDF5IO(file_path, mode="w") as io: + io.write(nwb_file) + + # Read the dataset from the file path + ds_from_file_path = from_nwb_file(file_path) + + # Assert the dimensions and attributes of the dataset + assert ds_from_file_path.sizes == { + "time": 100, + "individuals": 1, + "keypoints": 3, + "space": 2, + } + assert ds_from_file_path.attrs == { + "ds_type": "poses", + "fps": 10.0, + "time_unit": "seconds", + "source_software": "DeepLabCut", + "source_file": file_path, + } + + # Read the same dataset from an open NWB file + with NWBHDF5IO(file_path, mode="r") as io: + nwb_file = io.read() + ds_from_open_file = from_nwb_file(nwb_file) + # Check that it's identical to the dataset read from the file path + # except for the "source_file" attribute + ds_from_file_path.attrs["source_file"] = None + xr.testing.assert_identical(ds_from_file_path, ds_from_open_file)