From 3a2b94bad42f085903a6265437c526ccac5bc4aa Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 13 Sep 2023 18:28:57 +0100 Subject: [PATCH] Add load_data. --- .gitattributes | 0 .gitignore | 88 +++++++++++++++ .pre-commit-config.yaml | 31 ++++++ LICENSE | 28 +++++ MANIFEST.in | 6 ++ README.md | 0 pyproject.toml | 110 +++++++++++++++++++ spikewrap/data_classes/base.py | 138 ++++++++++++++++++++++++ spikewrap/data_classes/preprocessing.py | 58 ++++++++++ spikewrap/examples/load_data.py | 24 +++++ spikewrap/pipeline/load_data.py | 95 ++++++++++++++++ spikewrap/utils/utils.py | 28 +++++ 12 files changed, 606 insertions(+) create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 pyproject.toml create mode 100644 spikewrap/data_classes/base.py create mode 100644 spikewrap/data_classes/preprocessing.py create mode 100644 spikewrap/examples/load_data.py create mode 100644 spikewrap/pipeline/load_data.py create mode 100644 spikewrap/utils/utils.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..e69de29 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..152c86f --- /dev/null +++ b/.gitignore @@ -0,0 +1,88 @@ +# Custom +.obsidian +slurm_logs/ +derivatives/ +tests/data/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask instance folder +instance/ + +# Sphinx documentation +docs/_build/ + +# MkDocs documentation +/site/ + +# PyBuilder +target/ + +# Pycharm and VSCode +.idea/ +venv/ +.vscode/ + +# IPython Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# OS +.DS_Store + +# written by setuptools_scm +**/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..aaafbb0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +exclude: 'README.md' + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-toml + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: requirements-txt-fixer + - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.280 + hooks: + - id: ruff + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.4.1 + hooks: + - id: mypy + additional_dependencies: + - types-setuptools + - types-PyYAML + - types-toml diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0a063e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ + +Copyright (c) 2023, Joe Ziminski +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of spikewrap nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..72385b6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +include LICENSE +include README.md +include *.png + +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2d9d946 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,110 @@ +[project] +name = "spikewrap" +authors = [{name = "Joe Ziminski", email= "joseph.j.ziminski@gmail.com"}] +description = "Run extracellular electrophysiology analysis with SpikeInterface" +readme = "README.md" +requires-python = ">=3.8.0" +dynamic = ["version"] + +license = {text = "BSD-3-Clause"} + +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Operating System :: OS Independent", + "License :: OSI Approved :: BSD License", +] + +dependencies = [ + "spikeinterface==0.98.2", + "spython", # I think missing from SI? + "submitit", + "PyYAML", + "toml", + "typeguard", + # sorter-specific + "tridesclous", + # "spyking-circus", TODO: this is not straightforward, requires mpi4py. TBD if we want to manage this. + "mountainsort5", + "docker", # TODO: windows only! 
+ "cuda-python", +] + +[project.urls] +homepage = "https://github.com/JoeZiminski/spikewrap" +bug_tracker = "https://github.com/JoeZiminski/spikewrap/issues" +documentation = "https://github.com/JoeZiminski/spikewrap" +source_code = "https://github.com/JoeZiminski/spikewrap" +user_support = "https://github.com/JoeZiminski/spikewrap/issues" + + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "coverage", + "tox", + "black", + "mypy", + "pre-commit", + "ruff", + "setuptools_scm", + "types-setuptools", + "types-PyYAML", + "types-toml", +] + +[build-system] +requires = [ + "setuptools>=45", + "wheel", + "setuptools_scm[toml]>=6.2", +] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true + +[tool.setuptools.packages.find] +include = ["spikewrap*"] +exclude = ["tests*"] + +[tool.pytest.ini_options] +addopts = "--cov=spikewrap" + +[tool.black] +target-version = ['py38', 'py39', 'py310'] +skip-string-normalization = false +line-length = 88 + +[tool.setuptools_scm] + +[tool.check-manifest] +ignore = [ + "*.yaml", + "tox.ini", + "tests/*", + "tests/test_unit/*", + "tests/test_integration/*", + ".flake8" +] + +[tool.ruff] +ignore = ["E501"] # E501: line length violation (let Black handle, ignore strings). 
+ +exclude = ["__init__.py","build",".eggs"] +select = ["I", "E", "F"] +fix = true + +[tool.cibuildwheel] +build = "cp38-* cp39-* cp310-*" + +[tool.cibuildwheel.macos] +archs = ["x86_64", "arm64"] + +[project.scripts] +spikewrap = "spikewrap.command_line_interface:main" diff --git a/spikewrap/data_classes/base.py b/spikewrap/data_classes/base.py new file mode 100644 index 0000000..86557cb --- /dev/null +++ b/spikewrap/data_classes/base.py @@ -0,0 +1,138 @@ +import fnmatch +from collections import UserDict +from collections.abc import ItemsView, KeysView, ValuesView +from dataclasses import dataclass +from itertools import chain +from pathlib import Path +from typing import Callable, Dict, List, Literal + + +@dataclass +class BaseUserDict(UserDict): + """ + Base class for `PreprocessingData` and `SortingData` + used for checking and formatting `base_path`, `sub_name` + and `run_names`. The layout of the `rawdata` and + `derivatives` folder is identical up to the run + folder, allowing use of this class for + preprocessing and sorting. 
@dataclass
class BaseUserDict(UserDict):
    """
    Base class for `PreprocessingData` and `SortingData`, used for
    checking and formatting `base_path`, `sub_name` and
    `sessions_and_runs`. The layout of the `rawdata` and `derivatives`
    folders is identical up to the run folder, allowing use of this
    class for both preprocessing and sorting.

    This is a `UserDict` subclass that exposes the `keys()`, `values()`
    and `items()` convenience functions over the backing `data` dict.

    Parameters
    ----------
    base_path : Path
        Path to the folder that contains the `rawdata` directory.

    sub_name : str
        Subject-level folder name, residing in `base_path/rawdata/`.

    sessions_and_runs : Dict
        Mapping from session folder name to a run name (str) or a list
        of run names.
    """

    base_path: Path
    sub_name: str
    sessions_and_runs: Dict

    def __post_init__(self) -> None:
        # The dataclass-generated `__init__` does not call
        # `UserDict.__init__`, so the backing dict is created here.
        self.data: Dict = {}
        self.base_path = Path(self.base_path)
        self.check_run_names_are_formatted_as_list()

    def check_run_names_are_formatted_as_list(self) -> None:
        """Normalise every session's run names to a list of strings."""
        for key, value in self.sessions_and_runs.items():
            if not isinstance(value, list):
                assert isinstance(
                    value, str
                ), "Run names must be string or list of strings"
                self.sessions_and_runs[key] = [value]

    def preprocessing_sessions_and_runs(self) -> List:
        """
        Return a flat list of `(ses_name, run_name)` tuples, ordered
        as they appear in `sessions_and_runs`.
        """
        ordered_ses_names = chain(
            *[[ses] * len(runs) for ses, runs in self.sessions_and_runs.items()]
        )
        ordered_run_names = chain(*self.sessions_and_runs.values())

        return list(zip(ordered_ses_names, ordered_run_names))

    def _validate_inputs(
        self,
        top_level_folder: Literal["rawdata", "derivatives"],
        get_top_level_folder: Callable,
        get_sub_level_folder: Callable,
        get_sub_path: Callable,
        get_run_path: Callable,
    ) -> None:
        """
        Check that the top-level (`rawdata` / `derivatives`), subject,
        session and run folders all exist, and that each run name
        contains exactly one SpikeGLX gate index, which must be zero.

        Parameters
        ----------
        top_level_folder : Literal["rawdata", "derivatives"]
            Name of the top-level folder being checked (used for error
            messages).

        get_top_level_folder, get_sub_level_folder, get_sub_path,
        get_run_path : Callable
            Accessors returning the `Path` of the corresponding folder
            level (e.g. `get_rawdata_top_level_path`).

        Raises
        ------
        AssertionError
            If any expected folder is missing, or a run name has a
            missing, duplicated or non-zero gate index.
        """
        assert get_top_level_folder().is_dir(), (
            f"Ensure there is a folder in base path called '"
            f"{top_level_folder}'.\n"
            f"No {top_level_folder} directory found at "
            f"{get_top_level_folder()}\n"
            f"where subject-level folders must be placed."
        )

        assert get_sub_level_folder().is_dir(), (
            f"Subject directory not found. {self.sub_name} "
            f"is not a folder in {get_top_level_folder()}"
        )

        for ses_name in self.sessions_and_runs.keys():
            assert (
                ses_path := get_sub_path(ses_name)
            ).is_dir(), f"{ses_name} was not found at folder path {ses_path}"

            for run_name in self.sessions_and_runs[ses_name]:
                assert (run_path := get_run_path(ses_name, run_name)).is_dir(), (
                    f"The run folder {run_path.stem} cannot be found at "
                    f"file path {run_path.parent}."
                )

                # SpikeGLX output names embed the gate as e.g. "_g0";
                # only gate index 0 is currently supported.
                gate_str = fnmatch.filter(run_name.split("_"), "g?")

                assert len(gate_str) > 0, (
                    f"The SpikeGLX gate index should be in the run name. "
                    f"It was not found in the name {run_name}."
                    f"\nEnsure the gate number is in the SpikeGLX-output filename."
                )

                assert len(gate_str) == 1, (
                    f"The SpikeGLX gate appears in the name "
                    f"{run_name} more than once"
                )

                assert int(gate_str[0][1:]) == 0, (
                    f"Gate with index larger than 0 is not supported. This is found "
                    f"in run name {run_name}. "
                )

    # Rawdata Paths --------------------------------------------------------------

    def get_rawdata_top_level_path(self) -> Path:
        """Return `base_path/rawdata`."""
        return self.base_path / "rawdata"

    def get_rawdata_sub_path(self) -> Path:
        """Return `base_path/rawdata/<sub_name>`."""
        return self.get_rawdata_top_level_path() / self.sub_name

    def get_rawdata_ses_path(self, ses_name: str) -> Path:
        """Return `base_path/rawdata/<sub_name>/<ses_name>`."""
        return self.get_rawdata_sub_path() / ses_name

    def get_rawdata_run_path(self, ses_name: str, run_name: str) -> Path:
        """Return `base_path/rawdata/<sub_name>/<ses_name>/ephys/<run_name>`."""
        return self.get_rawdata_ses_path(ses_name) / "ephys" / run_name

    def keys(self) -> KeysView:
        return self.data.keys()

    def items(self) -> ItemsView:
        return self.data.items()

    def values(self) -> ValuesView:
        return self.data.values()
@dataclass
class PreprocessingData(BaseUserDict):
    """
    Dictionary to store SpikeInterface preprocessing recordings.

    Details of the preprocessing steps are held in the dictionary keys
    (e.g. "0-raw", "1-raw-bandpass_filter",
    "2-raw_bandpass_filter-common_average") and the recording objects
    are held in the values. These are generated by the
    `pipeline.preprocess.run_preprocessing()` function.

    The class manages paths to raw data and preprocessing output, and
    defines methods to dump key information and the SpikeInterface
    binary to disk. Note that SI preprocessing is lazy: preprocessing
    only runs when `recording.get_traces()` is called, or when the
    data is saved to binary.

    Parameters
    ----------
    base_path : Union[Path, str]
        Path to the folder that contains the `rawdata` directory, in
        which the subject folders reside.

    sub_name : str
        Subject to preprocess. The subject top-level folder should
        reside in `base_path/rawdata/`.

    sessions_and_runs : Dict
        Mapping from session folder name to the SpikeGLX run name
        (i.e. not including the gate index) or a list of run names, in
        the order they should be processed.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        self._validate_rawdata_inputs()

        # Per-(session, run) store for sync-channel recordings,
        # populated by the loading pipeline alongside the "0-raw"
        # entries placed in `self.data`.
        self.sync: Dict = {}

        for ses_name, run_name in self.preprocessing_sessions_and_runs():
            utils.update(self.data, ses_name, run_name, {"0-raw": None})
            utils.update(self.sync, ses_name, run_name, None)

    def _validate_rawdata_inputs(self) -> None:
        """Check the rawdata, subject, session and run folders exist."""
        self._validate_inputs(
            "rawdata",
            self.get_rawdata_top_level_path,
            self.get_rawdata_sub_path,
            self.get_rawdata_ses_path,
            self.get_rawdata_run_path,
        )
def load_data(
    base_path: Union[Path, str],
    sub_name: str,
    sessions_and_runs: Dict[str, List[str]],
    data_format: str = "spikeglx",
) -> PreprocessingData:
    """
    Load raw data (from the `rawdata` folder). If multiple runs are
    selected in `sessions_and_runs`, these will be stored as segments
    on a SpikeInterface recording object.

    Parameters
    ----------
    base_path : Union[Path, str]
        Path to the folder that contains the `rawdata` directory, in
        which the subject folders reside.

    sub_name : str
        Subject to preprocess. The subject top-level folder should
        reside in `base_path/rawdata/`.

    sessions_and_runs : Dict[str, List[str]]
        A dictionary containing the sessions and runs to run through
        the pipeline. Each session should be a session-level folder
        name residing in the passed `sub_name` folder. Each session to
        run is a key in the `sessions_and_runs` dict. For each session
        key, the value can be a single run name (str) or a list of run
        names. The runs will be processed in the order passed.

    data_format : str
        The data type format to load. Currently only "spikeglx" is
        accepted.

    Returns
    -------
    PreprocessingData
        Class containing the SpikeInterface recording object and
        information on the data filepaths.

    TODO
    ----
    Figure out the format from the data itself, instead of passing as
    argument. Do this when adding the next supported format.
    """
    empty_data_class = PreprocessingData(Path(base_path), sub_name, sessions_and_runs)

    if data_format == "spikeglx":
        return _load_spikeglx_data(empty_data_class)

    # Unknown formats are a hard error rather than a silent no-op.
    raise RuntimeError("`data_format` not recognised.")
+ """ + for ses_name, run_name in preprocess_data.preprocessing_sessions_and_runs(): + run_path = preprocess_data.get_rawdata_run_path(ses_name, run_name) + assert run_name == run_path.name, "TODO" + + with_sync, without_sync = [ + se.read_spikeglx( + folder_path=run_path, + stream_id="imec0.ap", + all_annotations=True, + load_sync_channel=sync, + ) + for sync in [True, False] + ] + preprocess_data[ses_name][run_name]["0-raw"] = without_sync + preprocess_data.sync[ses_name][run_name] = with_sync + + utils.message_user(f"Raw session data was loaded from {run_path}") + + return preprocess_data diff --git a/spikewrap/utils/utils.py b/spikewrap/utils/utils.py new file mode 100644 index 0000000..4ec7e80 --- /dev/null +++ b/spikewrap/utils/utils.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import copy +import os +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Tuple, Union + +import numpy as np +import yaml + + +def update(dict_, ses_name, run_name, value): + try: + dict_[ses_name][run_name] = value + except KeyError: + dict_[ses_name] = {run_name: value} + +def message_user(message: str) -> None: + """ + Method to interact with user. + + Parameters + ---------- + message : str + Message to print. + """ + print(f"\n{message}")