diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 2f09a8ee..e2c9b5e5 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -62,11 +62,6 @@ jobs: needs: [build_sdist_wheels] runs-on: ubuntu-latest steps: - - uses: actions/download-artifact@v3 + - uses: neuroinformatics-unit/actions/upload_pypi@v2 with: - name: artifact - path: dist - - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - user: __token__ - password: ${{ secrets.TWINE_API_KEY }} + secret-pypi-key: ${{ secrets.TWINE_API_KEY }} diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst index 293e98b5..78e27563 100644 --- a/docs/source/api_index.rst +++ b/docs/source/api_index.rst @@ -10,6 +10,7 @@ Input/Output .. autosummary:: :toctree: api + from_file from_sleap_file from_dlc_file from_dlc_df diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index fef1f348..0ea2e5f6 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -73,6 +73,11 @@ Then, depending on the source of your data, use one of the following functions: Load from [SLEAP analysis files](sleap:tutorials/analysis) (.h5): ```python ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30) + +# or equivalently +ds = load_poses.from_file( + "/path/to/file.analysis.h5", source_software="SLEAP", fps=30 +) ``` ::: @@ -86,6 +91,11 @@ ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30) You may also load .csv files (assuming they are formatted as DeepLabCut expects them): ```python ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30) + +# or equivalently +ds = load_poses.from_file( + "/path/to/file.csv", source_software="DeepLabCut", fps=30 +) ``` If you have already imported the data into a pandas DataFrame, you can @@ -103,6 +113,11 @@ ds = load_poses.from_dlc_df(df, fps=30) Load from LightningPose (LP) files (.csv): ```python ds = load_poses.from_lp_file("/path/to/file.analysis.csv", fps=30) + +# or equivalently +ds = load_poses.from_file( + "/path/to/file.analysis.csv", source_software="LightningPose", fps=30 +) ``` ::: diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index ff0aa929..be842273 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -21,6 +21,52 @@ logger = logging.getLogger(__name__) +def from_file( + file_path: Union[Path, str], + source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + fps: Optional[float] = None, +) -> xr.Dataset: + """Load pose tracking data from a DeepLabCut (DLC), LightningPose (LP) or + SLEAP output file into an xarray Dataset. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing predicted poses. The file format must + be among those supported by the ``from_dlc_file()``, + ``from_slp_file()`` or ``from_lp_file()`` functions, + since one of these functions will be called internally, based on + the value of ``source_software``. + source_software : "DeepLabCut", "SLEAP" or "LightningPose" + The source software of the file. + fps : float, optional + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + See Also + -------- + movement.io.load_poses.from_dlc_file + movement.io.load_poses.from_sleap_file + movement.io.load_poses.from_lp_file + """ + + if source_software == "DeepLabCut": + return from_dlc_file(file_path, fps) + elif source_software == "SLEAP": + return from_sleap_file(file_path, fps) + elif source_software == "LightningPose": + return from_lp_file(file_path, fps) + else: + raise log_error( + ValueError, f"Unsupported source software: {source_software}" + ) + + def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: """Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame. diff --git a/movement/sample_data.py b/movement/sample_data.py index 03b59d35..c5fdc02e 100644 --- a/movement/sample_data.py +++ b/movement/sample_data.py @@ -179,10 +179,9 @@ def fetch_sample_data( file for file in metadata if file["file_name"] == filename ) - if file_metadata["source_software"] == "SLEAP": - ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"]) - elif file_metadata["source_software"] == "DeepLabCut": - ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"]) - elif file_metadata["source_software"] == "LightningPose": - ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"]) + ds = load_poses.from_file( + file_path, + source_software=file_metadata["source_software"], + fps=file_metadata["fps"], + ) return ds diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 37f76f6a..e6256056 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import h5py import numpy as np import pytest @@ -239,3 +241,25 @@ def test_load_multi_animal_from_lp_file_raises(self): file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv") with pytest.raises(ValueError): load_poses.from_lp_file(file_path) + + @pytest.mark.parametrize( + "source_software", ["SLEAP", "DeepLabCut", "LightningPose", "Unknown"] + ) + @pytest.mark.parametrize("fps", [None, 30, 60.0]) + def test_from_file_delegates_correctly(self, source_software, fps): + """Test that the from_file() function delegates to the correct + loader function according to the source_software.""" + + software_to_loader = { + "SLEAP": "movement.io.load_poses.from_sleap_file", + "DeepLabCut": "movement.io.load_poses.from_dlc_file", + "LightningPose": "movement.io.load_poses.from_lp_file", + } + + if source_software == "Unknown": + with pytest.raises(ValueError, match="Unsupported source"): + load_poses.from_file("some_file", source_software) + else: + with patch(software_to_loader[source_software]) as mock_loader: + load_poses.from_file("some_file", source_software, fps) + mock_loader.assert_called_with("some_file", fps)