Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I/O support for the ndx-pose NWB extension #166

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
bd12d04
Create nwb_export.py
edeno Apr 18, 2024
c5319b9
NWB requires one file per individual
edeno Apr 18, 2024
d82fe30
Add script
edeno Apr 19, 2024
d889105
Remove import error handling
edeno Apr 19, 2024
72aea47
Add nwb optional dependencies
edeno Apr 19, 2024
12bf83f
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
742bf86
Add example docstring
edeno Apr 19, 2024
a06d485
Rename to fit module naming pattern
edeno Apr 19, 2024
739c4d8
Add import from nwb
edeno Apr 19, 2024
aef9b0c
Merge branch 'main' into main
edeno Apr 22, 2024
ce28f90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
b9599a9
Merge remote-tracking branch 'upstream/main'
edeno Jun 8, 2024
2491cf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
58bef93
Apply suggestions from code review
edeno Jun 8, 2024
3ab9aa4
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
910ce90
Cleanup of docstrings and variable names from code review
edeno Jun 8, 2024
3f9a53b
Rename function for clarity
edeno Jun 8, 2024
3cd991d
Update with example converting back to movement
edeno Jun 8, 2024
3aa1b11
Add file validation and handling for single path
edeno Jun 8, 2024
e56cf6d
Add preliminary tests
edeno Jun 8, 2024
99a90c1
Convert to numpy array
edeno Jun 9, 2024
02b9975
Handle lack of confidence
edeno Jun 9, 2024
a2ac053
Display xarray
edeno Jun 9, 2024
84a495d
Refactor tests
edeno Jun 9, 2024
90c3287
Merge remote-tracking branch 'upstream/main'
edeno Jun 10, 2024
e9e1cef
Create nwb_export.py
edeno Apr 18, 2024
3ccd71c
NWB requires one file per individual
edeno Apr 18, 2024
f906cd5
Add script
edeno Apr 19, 2024
d35d9c2
Remove import error handling
edeno Apr 19, 2024
e5726d4
Add nwb optional dependencies
edeno Apr 19, 2024
53f505b
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
f1d480d
Add example docstring
edeno Apr 19, 2024
4b162cf
Rename to fit module naming pattern
edeno Apr 19, 2024
4b887be
Add import from nwb
edeno Apr 19, 2024
96ee7ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
2f2625d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
1c7c2e3
Apply suggestions from code review
edeno Jun 8, 2024
4191ae8
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
4202ff6
Cleanup of docstrings and variable names from code review
edeno Jun 8, 2024
3188b0b
Rename function for clarity
edeno Jun 8, 2024
4908040
Update with example converting back to movement
edeno Jun 8, 2024
da43e87
Add file validation and handling for single path
edeno Jun 8, 2024
9d34939
Add preliminary tests
edeno Jun 8, 2024
56a6672
Convert to numpy array
edeno Jun 9, 2024
b37b2c6
Handle lack of confidence
edeno Jun 9, 2024
f7d48ce
Display xarray
edeno Jun 9, 2024
0606add
Refactor tests
edeno Jun 9, 2024
dbf804a
Merge branch 'main' of https://github.com/edeno/movement
edeno Oct 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/nwb_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Export pose tracks to NWB
============================

Export pose tracks to NWB
"""
edeno marked this conversation as resolved.
Show resolved Hide resolved

import datetime

from pynwb import NWBFile

from movement import sample_data
from movement.io.nwb import convert_movement_to_nwb

# Load the sample data
ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv")

# The dataset has two individuals
# we will create two NWBFiles for each individual

session_start_time = datetime.datetime.now(datetime.timezone.utc)
nwbfile_individual1 = NWBFile(
session_description="session_description",
identifier="individual1",
session_start_time=session_start_time,
)
nwbfile_individual2 = NWBFile(
session_description="session_description",
identifier="individual2",
session_start_time=session_start_time,
)

nwbfiles = [nwbfile_individual1, nwbfile_individual2]

# Convert the dataset to NWB
# This will create PoseEstimation and Skeleton objects for each individual
# and add them to the NWBFile
convert_movement_to_nwb(nwbfiles, ds)
edeno marked this conversation as resolved.
Show resolved Hide resolved
213 changes: 213 additions & 0 deletions movement/io/nwb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
from typing import Optional, Union
edeno marked this conversation as resolved.
Show resolved Hide resolved

import ndx_pose
import numpy as np
import pynwb
import xarray as xr


def _create_pose_and_skeleton_objects(
ds: xr.Dataset,
subject: str,
pose_estimation_series_kwargs: Optional[dict] = None,
pose_estimation_kwargs: Optional[dict] = None,
skeleton_kwargs: Optional[dict] = None,
) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]:
"""Creates PoseEstimation and Skeletons objects from a movement xarray
dataset.
edeno marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
ds : xr.Dataset
edeno marked this conversation as resolved.
Show resolved Hide resolved
Movement dataset containing the data to be converted to NWB.
edeno marked this conversation as resolved.
Show resolved Hide resolved
subject : str
Name of the subject to be converted.
edeno marked this conversation as resolved.
Show resolved Hide resolved
pose_estimation_series_kwargs : dict, optional
PoseEstimationSeries keyword arguments. See ndx_pose, by default None
pose_estimation_kwargs : dict, optional
PoseEstimation keyword arguments. See ndx_pose, by default None
skeleton_kwargs : dict, optional
Skeleton keyword arguments. See ndx_pose, by default None

Returns
-------
pose_estimation : list[ndx_pose.PoseEstimation]
List of PoseEstimation objects
skeletons : ndx_pose.Skeletons
Skeletons object containing all skeletons

"""
if pose_estimation_series_kwargs is None:
pose_estimation_series_kwargs = dict(
reference_frame="(0,0,0) corresponds to ...",
confidence_definition=None,
conversion=1.0,
resolution=-1.0,
offset=0.0,
starting_time=None,
comments="no comments",
description="no description",
control=None,
control_description=None,
)

if skeleton_kwargs is None:
skeleton_kwargs = dict(edges=None)

if pose_estimation_kwargs is None:
pose_estimation_kwargs = dict(
original_videos=None,
labeled_videos=None,
dimensions=None,
devices=None,
scorer=None,
source_software_version=None,
)

pose_estimation_series = []

for keypoint in ds.keypoints.to_numpy():
pose_estimation_series.append(
ndx_pose.PoseEstimationSeries(
name=keypoint,
data=ds.sel(keypoints=keypoint).position.to_numpy(),
confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(),
unit="pixels",
timestamps=ds.sel(keypoints=keypoint).time.to_numpy(),
**pose_estimation_series_kwargs,
)
)

skeleton_list = [
ndx_pose.Skeleton(
name=f"{subject}_skeleton",
nodes=ds.keypoints.to_numpy().tolist(),
**skeleton_kwargs,
)
]

bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist())
description = (
f"Estimated positions of {bodyparts_str} of"
f"{subject} using {ds.source_software}."
)

pose_estimation = [
ndx_pose.PoseEstimation(
name="PoseEstimation",
pose_estimation_series=pose_estimation_series,
description=description,
source_software=ds.source_software,
skeleton=skeleton_list[-1],
**pose_estimation_kwargs,
)
]

skeletons = ndx_pose.Skeletons(skeletons=skeleton_list)

return pose_estimation, skeletons


def convert_movement_to_nwb(
edeno marked this conversation as resolved.
Show resolved Hide resolved
nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile],
ds: xr.Dataset,
pose_estimation_series_kwargs: Optional[dict] = None,
pose_estimation_kwargs: Optional[dict] = None,
skeletons_kwargs: Optional[dict] = None,
):
if isinstance(nwbfiles, pynwb.NWBFile):
nwbfiles = [nwbfiles]

if len(nwbfiles) != len(ds.individuals):
raise ValueError(
"Number of NWBFiles must be equal to the number of individuals. "
"NWB requires one file per individual."
)
edeno marked this conversation as resolved.
Show resolved Hide resolved

for nwbfile, subject in zip(nwbfiles, ds.individuals.to_numpy()):
pose_estimation, skeletons = _create_pose_and_skeleton_objects(
ds.sel(individuals=subject),
subject,
pose_estimation_series_kwargs,
pose_estimation_kwargs,
skeletons_kwargs,
)
try:
behavior_pm = nwbfile.create_processing_module(
name="behavior",
description="processed behavioral data",
)
except ValueError:
print("Behavior processing module already exists. Skipping...")
behavior_pm = nwbfile.processing["behavior"]

try:
behavior_pm.add(skeletons)
except ValueError:
print("Skeletons already exists. Skipping...")
try:
behavior_pm.add(pose_estimation)
except ValueError:
print("PoseEstimation already exists. Skipping...")


def _convert_pse(
edeno marked this conversation as resolved.
Show resolved Hide resolved
pes: ndx_pose.PoseEstimationSeries,
keypoint: str,
subject_name: str,
source_software: str,
source_file: Optional[str] = None,
):
attrs = {
"fps": int(np.median(1 / np.diff(pes.timestamps))),
edeno marked this conversation as resolved.
Show resolved Hide resolved
"time_units": pes.timestamps_unit,
"source_software": source_software,
"source_file": source_file,
}
n_space_dims = pes.data.shape[1]
space_dims = ["x", "y", "z"]

return xr.Dataset(
data_vars={
"position": (
["time", "individuals", "keypoints", "space"],
pes.data[:, np.newaxis, np.newaxis, :],
),
"confidence": (
["time", "individuals", "keypoints"],
pes.confidence[:, np.newaxis, np.newaxis],
),
},
coords={
"time": pes.timestamps,
"individuals": [subject_name],
"keypoints": [keypoint],
"space": space_dims[:n_space_dims],
},
attrs=attrs,
)


def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly for this function, it's public, so a full docstring is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seems to be an assymetry between the two converter functions:
convert_nwb_to_movement reads nwb files from disk, while convert_movement_to_nwb writes to the NWB file handlers, without actually writing to disk.

To be consistent with our other loading/saving functions, I'd prefer to have functions to read from nwb file(s) on disk to movement ds, and functions for doing the exact opposite, i.e. write from movement ds to nwb file(s) on disk.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, in our other IO functions we accept both str and pathlib.Path as file paths.
In this particular fuction it would make sense to accept either a single path (Union[Path, str]) (which would work for single-subject datasets), or a list of such paths.

You could alro re-use our movement.io.validators.ValidFile validator for checking the paths, either directly, or via the _validate_file_path utility that can be found in movement.io.save_poses

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my reasoning for not writing from movement dataset to NWB file on disk is because there are often times where you'd like to add other things to the NWB file before writing to disk.

datasets = []
for path in nwb_filepaths:
with pynwb.NWBHDF5IO(path, mode="r") as io:
nwbfile = io.read()
pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"]
source_software = pose_estimation.fields["source_software"]
pose_estimation_series = pose_estimation.fields[
"pose_estimation_series"
]

for keypoint, pes in pose_estimation_series.items():
datasets.append(
_convert_pse(
pes,
keypoint,
subject_name=nwbfile.identifier,
source_software=source_software,
source_file=None,
)
)

return xr.merge(datasets)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dev = [
"types-PyYAML",
"types-requests",
]
nwb = ["pynwb", "ndx-pose"]
edeno marked this conversation as resolved.
Show resolved Hide resolved

[build-system]
requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
Expand Down
Loading