Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I/O support for the ndx-pose NWB extension: take 2 #360

Draft
wants to merge 52 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
a360d27
Create nwb_export.py
edeno Apr 18, 2024
44edddf
NWB requires one file per individual
edeno Apr 18, 2024
927c655
Add script
edeno Apr 19, 2024
7ccc5b6
Remove import error handling
edeno Apr 19, 2024
91eb4e5
Add nwb optional dependencies
edeno Apr 19, 2024
58b80c1
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
32692c7
Add example docstring
edeno Apr 19, 2024
82e62ce
Rename to fit module naming pattern
edeno Apr 19, 2024
11c1317
Add import from nwb
edeno Apr 19, 2024
1b3253d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
e473eea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
393ebbb
Apply suggestions from code review
edeno Jun 8, 2024
e7d2c68
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
43c1a73
Cleanup of docstrings and variable names from code review
edeno Jun 8, 2024
0af1660
Rename function for clarity
edeno Jun 8, 2024
5a14ad1
Update with example converting back to movement
edeno Jun 8, 2024
54fe8f3
Add file validation and handling for single path
edeno Jun 8, 2024
510c8a9
Add preliminary tests
edeno Jun 8, 2024
48c1d94
Convert to numpy array
edeno Jun 9, 2024
9b81b49
Handle lack of confidence
edeno Jun 9, 2024
8a8f578
Display xarray
edeno Jun 9, 2024
9ebfeb9
Refactor tests
edeno Jun 9, 2024
58bdae2
Create nwb_export.py
edeno Apr 18, 2024
57841fa
NWB requires one file per individual
edeno Apr 18, 2024
d2fd039
Remove import error handling
edeno Apr 19, 2024
2c3fb89
Add nwb optional dependencies
edeno Apr 19, 2024
fd261e4
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
001ea4b
Rename to fit module naming pattern
edeno Apr 19, 2024
08e9e33
Add import from nwb
edeno Apr 19, 2024
ee5cacb
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
b625f8a
Add file validation and handling for single path
edeno Jun 8, 2024
f43a0ad
Convert to numpy array
edeno Jun 9, 2024
216b01d
fix logging module import
niksirbi Nov 29, 2024
f841f30
constrained pynwb>=0.2.1
niksirbi Nov 29, 2024
f065f3f
fixed existing unit tests
niksirbi Nov 29, 2024
ddeeea6
add key_name argument to convert_nwb_to_movement
niksirbi Nov 29, 2024
060e051
tests should only create temp file
niksirbi Dec 11, 2024
621e79c
use Generator instead of legacy np.random.random
niksirbi Dec 11, 2024
63eafb3
reorder dims and use from_numpy for creating movement ds
niksirbi Dec 12, 2024
e137cb3
define default nwb kwargs as constants
niksirbi Dec 12, 2024
b06ccbd
renamed and reformatted `add_movement_dataset_to_nwb` to `ds_to_nwb`
niksirbi Dec 12, 2024
51bb1af
Expanded module-level docstring
niksirbi Dec 12, 2024
179963d
use individual instead of subject
niksirbi Dec 12, 2024
cfeaca8
refactored functions for loading ds from nwb
niksirbi Dec 12, 2024
3b24d6d
make mypy happy with numpy typing
niksirbi Dec 18, 2024
c8d1b82
rename nwb example
niksirbi Dec 18, 2024
0e4849a
renamed private func for creating pose estimation and skeletons objects
niksirbi Dec 18, 2024
1e5bfef
incorporate NWB loading into load_poses module
niksirbi Dec 18, 2024
fa0c1db
incorporate NWB saving function into save_poses module
niksirbi Dec 18, 2024
84d7ead
simplified private nwb functions
niksirbi Dec 18, 2024
7adba94
provide examples in docstrings instead of sphinx gallery example
niksirbi Dec 18, 2024
b9305e4
fix docstring syntax error
niksirbi Dec 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions movement/io/_nwb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Functions to convert between movement poses datasets and NWB files.

The pose tracks in NWB files are formatted according to the ``ndx-pose``
NWB extension, see https://github.com/rly/ndx-pose.
"""

import logging

import ndx_pose
import pynwb
import xarray as xr

from movement.utils.logging import log_warning

logger = logging.getLogger(__name__)

# Default keyword arguments for the ndx-pose Skeleton,
# PoseEstimationSeries and PoseEstimation constructors.
# These can be selectively overridden by user-supplied kwargs
# (see _ds_to_pose_and_skeleton_objects).

# Skeleton defaults: edges=None means no keypoint connectivity is defined.
SKELETON_KWARGS = dict(edges=None)

# PoseEstimationSeries defaults: placeholder reference frame and
# "no information available" values for the optional metadata fields.
POSE_ESTIMATION_SERIES_KWARGS = dict(
    reference_frame="(0,0,0) corresponds to ...",
    confidence_definition=None,
    conversion=1.0,
    resolution=-1.0,
    offset=0.0,
    starting_time=None,
    comments="no comments",
    description="no description",
    control=None,
    control_description=None,
)

# PoseEstimation defaults: no associated videos, devices or scorer metadata.
POSE_ESTIMATION_KWARGS = dict(
    original_videos=None,
    labeled_videos=None,
    dimensions=None,
    devices=None,
    scorer=None,
    source_software_version=None,
)


def _merge_kwargs(defaults, overrides):
return {**defaults, **(overrides or {})}


def _ds_to_pose_and_skeleton_objects(
    ds: xr.Dataset,
    pose_estimation_series_kwargs: dict | None = None,
    pose_estimation_kwargs: dict | None = None,
    skeleton_kwargs: dict | None = None,
) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]:
    """Create PoseEstimation and Skeletons objects from a ``movement`` dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        A single-individual ``movement`` poses dataset.
    pose_estimation_series_kwargs : dict, optional
        PoseEstimationSeries keyword arguments.
        See ``ndx_pose``, by default None
    pose_estimation_kwargs : dict, optional
        PoseEstimation keyword arguments. See ``ndx_pose``, by default None
    skeleton_kwargs : dict, optional
        Skeleton keyword arguments. See ``ndx_pose``, by default None

    Returns
    -------
    pose_estimation : list[ndx_pose.PoseEstimation]
        List of PoseEstimation objects
    skeletons : ndx_pose.Skeletons
        Skeletons object containing all skeletons

    """
    # Start from the module-level defaults, letting any
    # user-provided kwargs take precedence.
    series_kwargs = _merge_kwargs(
        POSE_ESTIMATION_SERIES_KWARGS, pose_estimation_series_kwargs
    )
    estimation_kwargs = _merge_kwargs(
        POSE_ESTIMATION_KWARGS, pose_estimation_kwargs
    )
    skel_kwargs = _merge_kwargs(SKELETON_KWARGS, skeleton_kwargs)

    # The dataset is single-individual, so there is exactly one name.
    individual = ds.individuals.values.item()
    keypoint_names = ds.keypoints.to_numpy().tolist()

    # One PoseEstimationSeries per keypoint, carrying its position
    # track, confidence scores and timestamps.
    pose_estimation_series = [
        ndx_pose.PoseEstimationSeries(
            name=keypoint,
            data=ds.sel(keypoints=keypoint).position.to_numpy(),
            confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(),
            unit="pixels",
            timestamps=ds.sel(keypoints=keypoint).time.to_numpy(),
            **series_kwargs,
        )
        for keypoint in keypoint_names
    ]

    # A single Skeleton for the individual, with one node per keypoint.
    skeleton = ndx_pose.Skeleton(
        name=f"{individual}_skeleton",
        nodes=keypoint_names,
        **skel_kwargs,
    )

    # Group all series into one PoseEstimation object.
    bodyparts_str = ", ".join(keypoint_names)
    description = (
        f"Estimated positions of {bodyparts_str} for "
        f"{individual} using {ds.source_software}."
    )
    pose_estimation = [
        ndx_pose.PoseEstimation(
            name="PoseEstimation",
            pose_estimation_series=pose_estimation_series,
            description=description,
            source_software=ds.source_software,
            skeleton=skeleton,
            **estimation_kwargs,
        )
    ]
    skeletons = ndx_pose.Skeletons(skeletons=[skeleton])

    return pose_estimation, skeletons


def _write_behavior_processing_module(
    nwb_file: pynwb.NWBFile,
    pose_estimation: ndx_pose.PoseEstimation,
    skeletons: ndx_pose.Skeletons,
) -> None:
    """Write behaviour processing data to an NWB file.

    PoseEstimation or Skeletons objects will be written to the NWB file's
    "behavior" processing module, formatted according to the ``ndx-pose`` NWB
    extension. If the module does not exist, it will be created.
    Data will not overwrite any existing objects in the NWB file.

    Parameters
    ----------
    nwb_file : pynwb.NWBFile
        The NWB file object to which the data will be added.
    pose_estimation : ndx_pose.PoseEstimation
        PoseEstimation object containing the pose data for an individual.
    skeletons : ndx_pose.Skeletons
        Skeletons object containing the skeleton data for an individual.

    """
    # EAFP: creating the module raises ValueError if it already exists,
    # in which case we fetch and reuse the existing one.
    try:
        behavior_pm = nwb_file.create_processing_module(
            name="behavior",
            description="processed behavioral data",
        )
        logger.debug("Created behavior processing module in NWB file.")
    except ValueError:
        logger.debug(
            "Data will be added to existing behavior processing module."
        )
        behavior_pm = nwb_file.processing["behavior"]

    # Add each object, skipping (with a warning) any that already exist;
    # behavior_pm.add raises ValueError on duplicate names.
    for nwb_object, label in (
        (skeletons, "Skeletons"),
        (pose_estimation, "PoseEstimation"),
    ):
        try:
            behavior_pm.add(nwb_object)
            logger.info(f"Added {label} object to NWB file.")
        except ValueError:
            log_warning(f"{label} object already exists. Skipping...")
145 changes: 143 additions & 2 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
import numpy as np
import pandas as pd
import pynwb
import xarray as xr
from sleap_io.io.slp import read_labels
from sleap_io.model.labels import Labels
Expand Down Expand Up @@ -97,7 +98,11 @@
def from_file(
file_path: Path | str,
source_software: Literal[
"DeepLabCut", "SLEAP", "LightningPose", "Anipose"
"DeepLabCut",
"SLEAP",
"LightningPose",
"Anipose",
"NWB",
],
fps: float | None = None,
**kwargs,
Expand All @@ -112,11 +117,14 @@
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
functions will be called internally, based on
the value of ``source_software``.
source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose"
source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \
"NWB"}
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
The fps argument is ignored when source_software is "NWB", as the
frame rate will be estimated from timestamps in the file.
**kwargs : dict, optional
Additional keyword arguments to pass to the software-specific
loading functions that are listed under "See Also".
Expand All @@ -133,6 +141,7 @@
movement.io.load_poses.from_sleap_file
movement.io.load_poses.from_lp_file
movement.io.load_poses.from_anipose_file
movement.io.load_poses.from_nwb_file

Examples
--------
Expand All @@ -142,6 +151,12 @@
... )

"""
if source_software == "NWB" and fps is not None:
log_warning(

Check warning on line 155 in movement/io/load_poses.py

View check run for this annotation

Codecov / codecov/patch

movement/io/load_poses.py#L155

Added line #L155 was not covered by tests
"The fps argument is ignored when loading from an NWB file. "
"The frame rate will be estimated from timestamps in the file."
)

if source_software == "DeepLabCut":
return from_dlc_file(file_path, fps)
elif source_software == "SLEAP":
Expand All @@ -150,6 +165,8 @@
return from_lp_file(file_path, fps)
elif source_software == "Anipose":
return from_anipose_file(file_path, fps, **kwargs)
elif source_software == "NWB":
return from_nwb_file(file_path, **kwargs)

Check warning on line 169 in movement/io/load_poses.py

View check run for this annotation

Codecov / codecov/patch

movement/io/load_poses.py#L169

Added line #L169 was not covered by tests
else:
raise log_error(
ValueError, f"Unsupported source software: {source_software}"
Expand Down Expand Up @@ -825,3 +842,127 @@
return from_anipose_style_df(
anipose_df, fps=fps, individual_name=individual_name
)


def from_nwb_file(
    file: str | Path | pynwb.NWBFile,
    key_name: str = "PoseEstimation",
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an NWB file.

    The input can be a path to an NWB file on disk or an open NWB file object.
    The data will be extracted from the NWB file's behavior processing module,
    which is assumed to contain a ``PoseEstimation`` object formatted according
    to the ``ndx-pose`` NWB extension [1]_.

    Parameters
    ----------
    file : str | Path | NWBFile
        Path to the NWB file on disk (ending in ".nwb"),
        or an open NWBFile object.
    key_name: str, optional
        Name of the ``PoseEstimation`` object in the NWB "behavior"
        processing module, by default "PoseEstimation".

    Returns
    -------
    movement_ds : xr.Dataset
        A single-individual ``movement`` dataset containing the pose tracks,
        confidence scores, and associated metadata.

    References
    ----------
    .. [1] https://github.com/rly/ndx-pose

    Examples
    --------
    Open an NWB file and load pose tracks from the file object:

    >>> import pynwb
    >>> from movement.io import load_poses
    >>> with pynwb.NWBHDF5IO("path/to/file.nwb", mode="r") as io:
    ...     nwb_file = io.read()
    ...     ds = load_poses.from_nwb_file(nwb_file)

    Directly load pose tracks from an NWB file on disk:

    >>> ds = load_poses.from_nwb_file("path/to/file.nwb")

    Load two single-individual datasets from two NWB files and merge them
    into a multi-individual dataset:

    >>> import xarray as xr
    >>> ds_singles = [
    ...     load_poses.from_nwb_file(f) for f in ["id1.nwb", "id2.nwb"]
    ... ]
    >>> ds_multi = xr.merge(ds_singles)

    """
    if isinstance(file, str | Path):
        # Validate path, suffix and read permission before opening.
        valid_file = ValidFile(
            file, expected_permission="r", expected_suffix=[".nwb"]
        )
        # Extract the dataset while the file handle is still open.
        with pynwb.NWBHDF5IO(valid_file.path, mode="r") as io:
            nwb_file = io.read()
            ds = _ds_from_nwb_object(nwb_file, key_name=key_name)
            ds.attrs["source_file"] = valid_file.path
    elif isinstance(file, pynwb.NWBFile):
        ds = _ds_from_nwb_object(file, key_name=key_name)
        # No on-disk path is known for an already-open file object.
        ds.attrs["source_file"] = None
    else:
        raise log_error(
            TypeError,
            "Expected file to be one of following types: str, Path, "
            f"pynwb.NWBFile. Got {type(file)} instead.",
        )
    return ds


def _ds_from_nwb_object(
    nwb_file: pynwb.NWBFile,
    key_name: str = "PoseEstimation",
) -> xr.Dataset:
    """Extract a ``movement`` poses dataset from an open NWB file object.

    Parameters
    ----------
    nwb_file : pynwb.NWBFile
        An open NWB file object.
    key_name: str
        Name of the ``PoseEstimation`` object in the NWB "behavior"
        processing module, by default "PoseEstimation".

    Returns
    -------
    movement_ds : xr.Dataset
        A single-individual ``movement`` poses dataset

    """
    pose_estimation = nwb_file.processing["behavior"][key_name]
    source_software = pose_estimation.fields["source_software"]
    # Mapping of keypoint name -> PoseEstimationSeries
    series_by_keypoint = pose_estimation.fields["pose_estimation_series"]

    keypoint_datasets = []
    for keypoint, series in series_by_keypoint.items():
        position = np.asarray(series.data)  # shape: (n_time, n_space)
        # Fall back to NaN confidence when the series carries none.
        if getattr(series, "confidence", None) is None:
            confidence = np.full(position.shape[0], np.nan)  # (n_time,)
        else:
            confidence = np.asarray(series.confidence)

        # Median of the inverse inter-sample intervals gives the frame rate.
        # NOTE(review): assumes the series stores explicit timestamps
        # (not a start time + rate) — confirm against the writer.
        fps = np.nanmedian(1 / np.diff(series.timestamps))

        # Build a 1-keypoint, 1-individual dataset; axes expanded to
        # (time, space, keypoints, individuals) as from_numpy expects.
        keypoint_datasets.append(
            from_numpy(
                position[:, :, np.newaxis, np.newaxis],
                confidence[:, np.newaxis, np.newaxis],
                individual_names=[nwb_file.identifier],
                keypoint_names=[keypoint],
                fps=round(fps, 6),
                source_software=source_software,
            )
        )

    # Combine the per-keypoint datasets along the keypoints dimension.
    return xr.merge(keypoint_datasets)
Loading
Loading