Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide a generic load_poses.from_file() function #107 #110

Merged
merged 14 commits into from
Feb 19, 2024
Merged
9 changes: 2 additions & 7 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ jobs:
needs: [build_sdist_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/download-artifact@v3
- uses: neuroinformatics-unit/actions/upload_pypi@v2
with:
name: artifact
path: dist
- uses: pypa/[email protected]
with:
user: __token__
password: ${{ secrets.TWINE_API_KEY }}
secret-pypi-key: ${{ secrets.TWINE_API_KEY }}
1 change: 1 addition & 0 deletions docs/source/api_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Input/Output
.. autosummary::
:toctree: api

from_file
from_sleap_file
from_dlc_file
from_dlc_df
Expand Down
15 changes: 15 additions & 0 deletions docs/source/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ Then, depending on the source of your data, use one of the following functions:
Load from [SLEAP analysis files](sleap:tutorials/analysis) (.h5):
```python
ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.analysis.h5", source_software="SLEAP", fps=30
)
```
:::

Expand All @@ -86,6 +91,11 @@ ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)
You may also load .csv files (assuming they are formatted as DeepLabCut expects them):
```python
ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.csv", source_software="DeepLabCut", fps=30
)
```

If you have already imported the data into a pandas DataFrame, you can
Expand All @@ -103,6 +113,11 @@ ds = load_poses.from_dlc_df(df, fps=30)
Load from LightningPose (LP) files (.csv):
```python
ds = load_poses.from_lp_file("/path/to/file.analysis.csv", fps=30)

# or equivalently
ds = load_poses.from_file(
"/path/to/file.analysis.csv", source_software="LightningPose", fps=30
)
```
:::

Expand Down
46 changes: 46 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,52 @@
logger = logging.getLogger(__name__)


def from_file(
file_path: Union[Path, str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
fps: Optional[float] = None,
) -> xr.Dataset:
"""Load pose tracking data from a DeepLabCut (DLC), LightningPose (LP) or
SLEAP output file into an xarray Dataset.

Parameters
----------
file_path : pathlib.Path or str
Path to the file containing predicted poses. The file format must
be among those supported by the ``from_dlc_file()``,
``from_slp_file()`` or ``from_lp_file()`` functions,
since one of these functions will be called internally, based on
the value of ``source_software``.
source_software : "DeepLabCut", "SLEAP" or "LightningPose"
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
Dataset containing the pose tracks, confidence scores, and metadata.

See Also
--------
movement.io.load_poses.from_dlc_file
movement.io.load_poses.from_sleap_file
movement.io.load_poses.from_lp_file
"""

if source_software == "DeepLabCut":
return from_dlc_file(file_path, fps)
elif source_software == "SLEAP":
return from_sleap_file(file_path, fps)
elif source_software == "LightningPose":
return from_lp_file(file_path, fps)
niksirbi marked this conversation as resolved.
Show resolved Hide resolved
else:
raise log_error(
ValueError, f"Unsupported source software: {source_software}"
)


def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
"""Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame.

Expand Down
11 changes: 5 additions & 6 deletions movement/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ def fetch_sample_data(
file for file in metadata if file["file_name"] == filename
)

if file_metadata["source_software"] == "SLEAP":
ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "DeepLabCut":
ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "LightningPose":
ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"])
ds = load_poses.from_file(
file_path,
source_software=file_metadata["source_software"],
fps=file_metadata["fps"],
)
return ds
24 changes: 24 additions & 0 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import h5py
import numpy as np
import pytest
Expand Down Expand Up @@ -239,3 +241,25 @@ def test_load_multi_animal_from_lp_file_raises(self):
file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv")
with pytest.raises(ValueError):
load_poses.from_lp_file(file_path)

@pytest.mark.parametrize(
"source_software", ["SLEAP", "DeepLabCut", "LightningPose", "Unknown"]
)
@pytest.mark.parametrize("fps", [None, 30, 60.0])
def test_from_file_delegates_correctly(self, source_software, fps):
"""Test that the from_file() function delegates to the correct
loader function according to the source_software."""

software_to_loader = {
"SLEAP": "movement.io.load_poses.from_sleap_file",
"DeepLabCut": "movement.io.load_poses.from_dlc_file",
"LightningPose": "movement.io.load_poses.from_lp_file",
}

if source_software == "Unknown":
with pytest.raises(ValueError, match="Unsupported source"):
load_poses.from_file("some_file", source_software)
else:
with patch(software_to_loader[source_software]) as mock_loader:
load_poses.from_file("some_file", source_software, fps)
mock_loader.assert_called_with("some_file", fps)
Loading