Refactor tests
edeno committed Oct 25, 2024
1 parent f7d48ce commit 0606add
Showing 1 changed file with 161 additions and 154 deletions.
315 changes: 161 additions & 154 deletions tests/test_unit/test_nwb.py
@@ -1,8 +1,10 @@
import datetime

import ndx_pose
import numpy as np
import pynwb
import pytest
import xarray as xr
from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons
from pynwb import NWBHDF5IO, NWBFile
from pynwb.file import Subject

from movement import sample_data
from movement.io.nwb import (
@@ -19,8 +21,8 @@ def test_create_pose_and_skeleton_objects():

# Call the function
pose_estimation, skeletons = _create_pose_and_skeleton_objects(
ds,
subject="subject1",
ds.sel(individuals="individual1"),
subject="individual1",
pose_estimation_series_kwargs=None,
pose_estimation_kwargs=None,
skeleton_kwargs=None,
@@ -34,43 +36,59 @@
assert len(pose_estimation) == 1

# Assert the length of pose_estimation_series list
assert len(pose_estimation[0].pose_estimation_series) == 2
assert len(pose_estimation[0].pose_estimation_series) == 12

# Assert the name of the first PoseEstimationSeries
assert pose_estimation[0].pose_estimation_series[0].name == "keypoint1"

# Assert the name of the second PoseEstimationSeries
assert pose_estimation[0].pose_estimation_series[1].name == "keypoint2"
assert "snout" in pose_estimation[0].pose_estimation_series

# Assert the name of the Skeleton
assert skeletons.skeletons[0].name == "subject1_skeleton"
assert "individual1_skeleton" in skeletons.skeletons


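# Helper for the tests below: builds a synthetic PoseEstimationSeries with
# random position data, evenly spaced timestamps, and a confidence of 1 for every frame.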
def create_test_pose_estimation_series(
n_time=100, n_dims=2, keypoint="front_left_paw"
):
data = np.random.rand(
n_time, n_dims
) # num_frames x (x, y) but can be (x, y, z)
timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame
confidence = np.ones((n_time,)) # a confidence value for every frame
reference_frame = "(0,0,0) corresponds to ..."
confidence_definition = "Softmax output of the deep neural network."

return PoseEstimationSeries(
name=keypoint,
description="Marker placed around fingers of front left paw.",
data=data,
unit="pixels",
reference_frame=reference_frame,
timestamps=timestamps,
confidence=confidence,
confidence_definition=confidence_definition,
)


def test__convert_pose_estimation_series():
# Create a sample PoseEstimationSeries object
pose_estimation_series = ndx_pose.PoseEstimationSeries(
name="keypoint1",
data=np.random.rand(10, 3),
confidence=np.random.rand(10),
unit="pixels",
timestamps=np.arange(10),
pose_estimation_series = create_test_pose_estimation_series(
n_time=100, n_dims=2, keypoint="front_left_paw"
)

# Call the function
movement_dataset = _convert_pose_estimation_series(
pose_estimation_series,
keypoint="keypoint1",
subject_name="subject1",
keypoint="leftear",
subject_name="individual1",
source_software="software1",
source_file="file1",
)

# Assert the dimensions of the movement dataset
assert movement_dataset.dims == {
"time": 10,
assert movement_dataset.sizes == {
"time": 100,
"individuals": 1,
"keypoints": 1,
"space": 3,
"space": 2,
}

# Assert the values of the position variable
@@ -92,155 +110,144 @@ def test__convert_pose_estimation_series():
"source_software": "software1",
"source_file": "file1",
}
pose_estimation_series = create_test_pose_estimation_series(
n_time=50, n_dims=3, keypoint="front_left_paw"
)
# convert the second, 3-D series so the assertions below check its dimensions
movement_dataset = _convert_pose_estimation_series(
pose_estimation_series,
keypoint="front_left_paw",
subject_name="individual1",
source_software="software1",
source_file="file1",
)

# Assert the dimensions of the movement dataset
assert movement_dataset.sizes == {
"time": 50,
"individuals": 1,
"keypoints": 1,
"space": 3,
}


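# The following tests attach the "DLC_two-mice.predictions.csv" sample dataset
# to freshly created NWBFile objects, one per individual.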
def test_add_movement_dataset_to_nwb_single_file():
# Create a sample NWBFile
nwbfile = pynwb.NWBFile(
"session_description", "identifier", "session_start_time"
)
# Create a sample movement dataset
movement_dataset = xr.Dataset(
{
"keypoints": (["keypoints"], ["keypoint1", "keypoint2"]),
"position": (["time", "keypoints"], [[1, 2], [3, 4]]),
"confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]),
"time": [0, 1],
"individuals": ["subject1"],
}
ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
session_start_time = datetime.datetime.now(datetime.timezone.utc)
nwbfile_individual1 = NWBFile(
session_description="session_description",
identifier="individual1",
session_start_time=session_start_time,
)
add_movement_dataset_to_nwb(
nwbfile_individual1, ds.sel(individuals=["individual1"])
)
assert (
"PoseEstimation"
in nwbfile_individual1.processing["behavior"].data_interfaces
)
assert (
"Skeletons"
in nwbfile_individual1.processing["behavior"].data_interfaces
)
# Call the function
add_movement_dataset_to_nwb(nwbfile, movement_dataset)
# Assert the presence of PoseEstimation and Skeletons in the NWBFile
assert "PoseEstimation" in nwbfile.processing["behavior"]
assert "Skeletons" in nwbfile.processing["behavior"]


def test_add_movement_dataset_to_nwb_multiple_files():
# Create sample NWBFiles
nwbfiles = [
pynwb.NWBFile(
"session_description1", "identifier1", "session_start_time1"
),
pynwb.NWBFile(
"session_description2", "identifier2", "session_start_time2"
),
]
# Create a sample movement dataset
movement_dataset = xr.Dataset(
{
"keypoints": (["keypoints"], ["keypoint1", "keypoint2"]),
"position": (["time", "keypoints"], [[1, 2], [3, 4]]),
"confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]),
"time": [0, 1],
"individuals": ["subject1", "subject2"],
}
ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")
session_start_time = datetime.datetime.now(datetime.timezone.utc)
nwbfile_individual1 = NWBFile(
session_description="session_description",
identifier="individual1",
session_start_time=session_start_time,
)
nwbfile_individual2 = NWBFile(
session_description="session_description",
identifier="individual2",
session_start_time=session_start_time,
)
# Call the function
add_movement_dataset_to_nwb(nwbfiles, movement_dataset)
# Assert the presence of PoseEstimation and Skeletons in each NWBFile
for nwbfile in nwbfiles:
assert "PoseEstimation" in nwbfile.processing["behavior"]
assert "Skeletons" in nwbfile.processing["behavior"]

nwbfiles = [nwbfile_individual1, nwbfile_individual2]
add_movement_dataset_to_nwb(nwbfiles, ds)

def test_convert_nwb_to_movement():
# Create sample NWB files
nwb_filepaths = [
"/path/to/file1.nwb",
"/path/to/file2.nwb",
"/path/to/file3.nwb",
]
pose_estimation_series = {
"keypoint1": ndx_pose.PoseEstimationSeries(
name="keypoint1",
data=np.random.rand(10, 3),
confidence=np.random.rand(10),
unit="pixels",
timestamps=np.arange(10),
),
"keypoint2": ndx_pose.PoseEstimationSeries(
name="keypoint2",
data=np.random.rand(10, 3),
confidence=np.random.rand(10),
unit="pixels",
timestamps=np.arange(10),
),
}

# Mock the NWBHDF5IO read method
def mock_read(filepath):
nwbfile = pynwb.NWBFile(
"session_description", "identifier", "session_start_time"
)

pose_estimation = ndx_pose.PoseEstimation(
name="PoseEstimation",
pose_estimation_series=pose_estimation_series,
description="Pose estimation data",
source_software="software1",
skeleton=ndx_pose.Skeleton(
name="skeleton1", nodes=["node1", "node2"]
),
)
behavior_pm = pynwb.ProcessingModule(
name="behavior", description="Behavior data"
)
behavior_pm.add(pose_estimation)
nwbfile.add_processing_module(behavior_pm)
return nwbfile
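# Helper: assembles a complete NWBFile containing a Subject, a Skeleton, a camera
# Device, and a PoseEstimation object for three keypoints, and either returns it
# or writes it to "test_pose.nwb" on disk.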
def create_test_pose_nwb(identifier="subject1", write_to_disk=False):
# initialize an NWBFile object
nwbfile = NWBFile(
session_description="session_description",
identifier=identifier,
session_start_time=datetime.datetime.now(datetime.timezone.utc),
)

# Patch the NWBHDF5IO read method with the mock
with pytest.patch("pynwb.NWBHDF5IO.read", side_effect=mock_read):
# Call the function
movement_dataset = convert_nwb_to_movement(nwb_filepaths)
# add a subject to the NWB file
subject = Subject(subject_id=identifier, species="Mus musculus")
nwbfile.subject = subject

# Assert the dimensions of the movement dataset
assert movement_dataset.dims == {
"time": 10,
"individuals": 3,
"keypoints": 2,
"space": 3,
}
skeleton = Skeleton(
name="subject1_skeleton",
nodes=["front_left_paw", "body", "front_right_paw"],
edges=np.array([[0, 1], [1, 2]], dtype="uint8"),
subject=subject,
)

# Assert the values of the position variable
np.testing.assert_array_equal(
movement_dataset["position"].values,
np.concatenate(
[
pose_estimation_series["keypoint1"].data[
:, np.newaxis, np.newaxis, :
],
pose_estimation_series["keypoint2"].data[
:, np.newaxis, np.newaxis, :
],
],
axis=1,
),
skeletons = Skeletons(skeletons=[skeleton])

# create a device for the camera
camera1 = nwbfile.create_device(
name="camera1",
description="camera for recording behavior",
manufacturer="my manufacturer",
)

# Assert the values of the confidence variable
np.testing.assert_array_equal(
movement_dataset["confidence"].values,
np.concatenate(
[
pose_estimation_series["keypoint1"].confidence[
:, np.newaxis, np.newaxis
],
pose_estimation_series["keypoint2"].confidence[
:, np.newaxis, np.newaxis
],
],
axis=1,
),
n_time = 100
n_dims = 2 # 2D data
front_left_paw = create_test_pose_estimation_series(
n_time=n_time, n_dims=n_dims, keypoint="front_left_paw"
)

# Assert the attributes of the movement dataset
assert movement_dataset.attrs == {
"fps": np.nanmedian(
1 / np.diff(pose_estimation_series["keypoint1"].timestamps)
body = create_test_pose_estimation_series(
n_time=n_time, n_dims=n_dims, keypoint="body"
)
front_right_paw = create_test_pose_estimation_series(
n_time=n_time, n_dims=n_dims, keypoint="front_right_paw"
)

# store all PoseEstimationSeries in a list
pose_estimation_series = [front_left_paw, body, front_right_paw]

pose_estimation = PoseEstimation(
name="PoseEstimation",
pose_estimation_series=pose_estimation_series,
description=(
"Estimated positions of front paws" "of subject1 using DeepLabCut."
),
"time_units": pose_estimation_series["keypoint1"].timestamps_unit,
"source_software": "software1",
"source_file": None,
original_videos=["path/to/camera1.mp4"],
labeled_videos=["path/to/camera1_labeled.mp4"],
dimensions=np.array(
[[640, 480]], dtype="uint16"
), # pixel dimensions of the video
devices=[camera1],
scorer="DLC_resnet50_openfieldOct30shuffle1_1600",
source_software="DeepLabCut",
source_software_version="2.3.8",
skeleton=skeleton, # link to the skeleton object
)

behavior_pm = nwbfile.create_processing_module(
name="behavior",
description="processed behavioral data",
)
behavior_pm.add(skeletons)
behavior_pm.add(pose_estimation)

# write the NWBFile to disk
if write_to_disk:
path = "test_pose.nwb"
with NWBHDF5IO(path, mode="w") as io:
io.write(nwbfile)
else:
return nwbfile


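# Round-trip test: write the synthetic NWB file to disk, read it back with
# convert_nwb_to_movement, and check the dimensions of the resulting dataset.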
def test_convert_nwb_to_movement():
create_test_pose_nwb(write_to_disk=True)
nwb_filepaths = ["test_pose.nwb"]
movement_dataset = convert_nwb_to_movement(nwb_filepaths)

assert movement_dataset.sizes == {
"time": 100,
"individuals": 1,
"keypoints": 3,
"space": 2,
}
