Skip to content

Commit

Permalink
Continue adding docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jan 31, 2025
1 parent 7ca3c34 commit 677e155
Show file tree
Hide file tree
Showing 14 changed files with 296 additions and 283 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
slurm_logs/
derivatives/
gallery_builds
sg_execution_times.rst

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
55 changes: 0 additions & 55 deletions docs/source/sg_execution_times.rst

This file was deleted.

24 changes: 13 additions & 11 deletions spikewrap/configs/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ def get_configs(name: str) -> tuple[dict, dict]:
Returns
-------
pp_steps : a dictionary containing the preprocessing
step order (keys) and a [pp_name, kwargs]
list containing the spikeinterface preprocessing
step and keyword options.
pp_steps
a dictionary containing the preprocessing
step order (keys) and a [pp_name, kwargs]
list containing the spikeinterface preprocessing
step and keyword options.
sorter_options : a dictionary with sorter name (key) and
a dictionary of kwargs to pass to the
spikeinterface sorter class.
sorter_options
a dictionary with sorter name (key) and
a dictionary of kwargs to pass to the
spikeinterface sorter class.
"""
config_dir = get_configs_path()

Expand Down Expand Up @@ -126,11 +128,11 @@ def save_config_dict(config_dict: dict, name: str, folder: Path | None = None):
Parameters
----------
config_dict :
config_dict
The configs dictionary to save.
name :
name
The name of the YAML file (with or without the `.yaml` extension).
folder :
folder
If None (default), the config is saved in the spikewrap-managed
user configs folder. Otherwise, save in `folder`.
"""
Expand All @@ -151,7 +153,7 @@ def load_config_dict(filepath: Path) -> dict:
Parameters
----------
filepath :
filepath
The full path to the YAML file, including the file name and extension.
Returns
Expand Down
14 changes: 8 additions & 6 deletions spikewrap/configs/hpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ def default_slurm_options(partition: Literal["cpu", "gpu"] = "cpu") -> dict:
All arguments correspond to sbatch arguments except for:
`wait` : Whether to block the execution of the calling process until the job completes.
``wait``
Whether to block the execution of the calling process until the job completes.
`env_name` : The name of the Conda environment to run the job in. Defaults to the
active Conda environment of the calling process, or "spikewrap" if none is detected.
To modify this, update the returned dictionary directly.
``env_name``
The name of the Conda environment to run the job in. Defaults to the
active Conda environment of the calling process, or "spikewrap" if none is detected.
To modify this, update the returned dictionary directly.
Parameters
----------
partition :
partition
The SLURM partition to use.
Returns
-------
options :
options
Dictionary of SLURM job settings.
"""
env_name = os.environ.get("CONDA_DEFAULT_ENV")
Expand Down
10 changes: 5 additions & 5 deletions spikewrap/process/_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def get_run_paths(
Parameters
----------
file_format :
file_format
The data format of the electrophysiology recordings.
ses_path :
ses_path
The path to the session for which to detect the runs.
passed_run_names :
passed_run_names
The ordered names of the runs to retrieve. If "all", all detected runs are returned.
Otherwise, each run name in the list must match a detected run in the folders.
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_spikeglx_runs(ses_path: Path) -> list[Path]:
Parameters
----------
ses_path :
ses_path
The path to the session for which to detect the runs.
Returns
Expand Down Expand Up @@ -190,7 +190,7 @@ def get_openephys_runs(ses_path: Path) -> list[Path]:
Parameters
----------
ses_path :
ses_path
The path to the session for which to detect the runs.
Returns
Expand Down
81 changes: 14 additions & 67 deletions spikewrap/process/_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
from spikeinterface.core import BaseRecording
from typing import Callable

import numpy as np
import spikeinterface.full as si
Expand All @@ -25,9 +22,9 @@ def _fill_with_preprocessed_recordings(
Parameters
----------
preprocess_data :
preprocess_data
Dictionary to store the newly created recording objects, updated in-place.
pp_steps :
pp_steps
"preprocessing" entry of a "configs" dictionary. Formatted as
{step_num_str : [preprocessing_func_name, {pp_func_args}]}
"""
Expand All @@ -36,91 +33,41 @@ def _fill_with_preprocessed_recordings(
checked_pp_steps, pp_step_names = _check_and_sort_pp_steps(pp_steps, pp_funcs)

for step_num, pp_info in checked_pp_steps.items():
(
pp_name,
pp_options,
last_pp_step_output,
new_name,
) = _get_preprocessing_step_information(
pp_info, pp_step_names, preprocess_data, step_num
pp_name, pp_options = pp_info

last_pp_step_output, __ = _utils._get_dict_value_from_step_num(
preprocess_data, step_num=str(int(step_num) - 1)
)

preprocessed_recording = pp_funcs[pp_name](last_pp_step_output, **pp_options)

new_name = f"{step_num}-" + "-".join(["raw"] + pp_step_names[: int(step_num)])

preprocess_data[new_name] = preprocessed_recording


# Helpers for preprocessing steps dictionary -------------------------------------------


def _get_preprocessing_step_information(
pp_info: list,
pp_step_names: list[str],
preprocess_data: dict,
step_num: str,
) -> tuple[str, dict, BaseRecording, str]:
"""
Retrieve recording and details needed to apply a preprocessing step.
Extracts the name and options for the current step, determines
the output of the previous step, and constructs a descriptive name for the new
recording.
Parameters
----------
pp_info :
A list containing the preprocessing step name and a dictionary of options.
For example: ["common_reference", {"operator": "median"}].
pp_step_names :
A list of ordered preprocessing step names.
preprocess_data :
The Preprocessed._data dictionary to be filled.
step_num :
The current step number as a string.
Returns
-------
pp_name :
The name of the current preprocessing step.
pp_options :
The dictionary of parameters for the current step.
last_pp_step_output
The preprocessed SpikeInterface recording from the immediately
preceding step, usually a recording from the previous pipeline stage.
new_name :
A concatenated string representing the current preprocessing
step name in sequence, used as a key in `preprocess_data`.
"""
pp_name, pp_options = pp_info

last_pp_step_output, __ = _utils._get_dict_value_from_step_num(
preprocess_data, step_num=str(int(step_num) - 1)
)

new_name = f"{step_num}-" + "-".join(["raw"] + pp_step_names[: int(step_num)])

return pp_name, pp_options, last_pp_step_output, new_name


def _check_and_sort_pp_steps(pp_steps: dict, pp_funcs: dict) -> tuple[dict, list[str]]:
"""
Sort the preprocessing steps dictionary by order to be run
(based on the keys) and check the dictionary is valid.
Parameters
----------
pp_steps : dict
pp_steps
"preprocessing" entry of a "configs" dictionary. Formatted as
{step_num_str : [preprocessing_func_name, {pp_func_args}]}
pp_funcs :
pp_funcs
A dictionary linking preprocessing step names to the underlying
SpikeInterface preprocessing functions.
Returns
-------
pp_steps :
pp_steps
The checked pp_steps dictionary.
pp_step_names :
pp_step_names
List of ordered preprocessing step names (e.g. "bandpass_filter").
"""
_validate_pp_steps(pp_steps)
Expand All @@ -144,7 +91,7 @@ def _validate_pp_steps(pp_steps: dict) -> None:
Parameters
----------
pp_steps : dict
pp_steps
"preprocessing" entry of a "configs" dictionary. Formatted as
{step_num_str : [preprocessing_func_name, {pp_func_args}]}
"""
Expand Down
6 changes: 3 additions & 3 deletions spikewrap/process/_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ def save_sync_channel(
Parameters
----------
recording : BaseRecording
recording
The recording object from which to extract the sync channel data.
output_path : Path
output_path
The directory where the sync channel file will be saved.
file_format : {"spikeglx", "openephys"}
file_format
The format of the recording file. Determines how the sync channel is extracted.
Raises
Expand Down
33 changes: 14 additions & 19 deletions spikewrap/structure/_preprocessed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,25 @@

class Preprocessed:
"""
This class represents a single recording and its full preprocessing chain.
This is a class used internally by spikewrap.
This class should be treated as immutable after initialisation.
Preprocessed recordings are held in self._data, a dict with
keys as a str representing the current step in the preprocessing chain,
e.g. "0-raw", "0-raw_1-phase_shift_2-bandpass_filter"
and value the corresponding preprocessed recording.
Class for holding and managing the preprocessing
of a SpikeInterface recording.
Parameters
----------
recording :
The raw SpikeInterface recording (prior to preprocessing).
pp_steps :
Preprocessing configuration dictionary (see configs documentation).
output_path :
Path to output the saved fully preprocessed (i.e. last step in the chain)
recording to, with `save_binary()`.
recording
SpikeInterface raw recording object to be preprocessed.
pp_steps
Dictionary specifying preprocessing steps, see ``configs`` documentation.
output_path
Path where preprocessed recording is to be saved (i.e. run folder).
"""

def __init__(
self, recording: BaseRecording, pp_steps: dict, output_path: Path, name: str
):
# These parameters should be treated as constant and never changed
# during the lifetime of the class. Use the properties (which do not
# expose a setter) for both internal and external calls.
if name == canon.grouped_shankname():
self._preprocessed_path = output_path / canon.preprocessed_folder()
else:
Expand All @@ -53,12 +48,12 @@ def __init__(

def save_binary(self, chunk_duration_s: float = 2.0) -> None:
"""
Save the fully preprocessed data (i.e. last step in the
preprocessing chain) to binary file.
Save the fully preprocessed data (i.e. last step in
the preprocessing chain) to binary file.
Parameters
----------
chunk_duration_s :
chunk_duration_s
Writing chunk size in seconds.
"""
recording, __ = _utils._get_dict_value_from_step_num(self._data, "last")
Expand Down
Loading

0 comments on commit 677e155

Please sign in to comment.