diff --git a/src/roiextractors/extraction_tools.py b/src/roiextractors/extraction_tools.py index e3f4b9ee..dfe2a50d 100644 --- a/src/roiextractors/extraction_tools.py +++ b/src/roiextractors/extraction_tools.py @@ -31,6 +31,7 @@ DtypeType = DTypeLike IntType = Union[int, np.integer] FloatType = float +NoneType = type(None) def raise_multi_channel_or_depth_not_implemented(extractor_name: str): @@ -244,18 +245,16 @@ def read_numpy_memmap_video( return video_memap -def _pixel_mask_extractor(image_mask_, _roi_ids) -> list: - """Convert image mask to pixel mask. +def _pixel_mask_extractor(image_masks: np.ndarray) -> list: + """Convert image masks to pixel masks. Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images. The location and weight of each non-zero pixel is stored for each mask. Parameters ---------- - image_mask_: numpy.ndarray + image_masks: numpy.ndarray Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois). - _roi_ids: list - List of roi ids with length number_of_rois. Returns ------- @@ -265,11 +264,11 @@ def _pixel_mask_extractor(image_mask_, _roi_ids) -> list: the pixel. """ pixel_mask_list = [] - for i, roiid in enumerate(_roi_ids): - image_mask = np.array(image_mask_[:, :, i]) - _locs = np.where(image_mask > 0) - _pix_values = image_mask[image_mask > 0] - pixel_mask_list.append(np.vstack((_locs[0], _locs[1], _pix_values)).T) + for i in range(image_masks.shape[2]): + image_mask = image_masks[:, :, i] + locs = np.where(image_mask > 0) + pix_values = image_mask[image_mask > 0] + pixel_mask_list.append(np.vstack((locs[0], locs[1], pix_values)).T) return pixel_mask_list @@ -684,3 +683,31 @@ def get_package( f"\nThe required package'{package_name}' is not installed!\n" f"To install this package, please run\n\n\t{installation_instructions}\n" ) + + +def get_default_roi_locations_from_image_masks(image_masks: np.ndarray) -> np.ndarray: + """Calculate the default ROI locations from given image masks. + + This function takes a 3D numpy array of image masks and computes the median + coordinates of the maximum values in each 2D mask. The result is a 2D numpy + array where each column represents the (x, y) coordinates of the ROI for + each mask. + + Parameters + ---------- + image_masks : np.ndarray + A 3D numpy array of shape (height, width, num_rois) containing the image masks. + + Returns + ------- + np.ndarray + A 2D numpy array of shape (2, num_rois) where each column contains the + (x, y) coordinates of the ROI for each mask. + """ + num_rois = image_masks.shape[2] + roi_locations = np.zeros([2, num_rois], dtype="int") + for i in range(num_rois): + image_mask = image_masks[:, :, i] + max_value_indices = np.where(image_mask == np.amax(image_mask)) + roi_locations[:, i] = np.array([np.median(max_value_indices[0]), np.median(max_value_indices[1])]).T + return roi_locations diff --git a/src/roiextractors/extractors/numpyextractors/numpyextractors.py b/src/roiextractors/extractors/numpyextractors/numpyextractors.py index 2b738995..1d82ad39 100644 --- a/src/roiextractors/extractors/numpyextractors/numpyextractors.py +++ b/src/roiextractors/extractors/numpyextractors/numpyextractors.py @@ -3,23 +3,30 @@ Classes ------- NumpyImagingExtractor - An ImagingExtractor specified by timeseries .npy file, sampling frequency, and channel names. + An ImagingExtractor specified by timeseries np.ndarray or .npy file and sampling frequency. NumpySegmentationExtractor A Segmentation extractor specified by image masks and traces .npy files. """ from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, get_args import numpy as np -from ...extraction_tools import PathType, FloatType, ArrayType, IntType +from ...extraction_tools import ( + PathType, + FloatType, + ArrayType, + IntType, + NoneType, + get_default_roi_locations_from_image_masks, +) from ...imagingextractor import ImagingExtractor from ...segmentationextractor import SegmentationExtractor class NumpyImagingExtractor(ImagingExtractor): - """An ImagingExtractor specified by timeseries .npy file, sampling frequency, and channel names.""" + """An ImagingExtractor specified by timeseries np.ndarray or .npy file and sampling frequency.""" extractor_name = "NumpyImagingExtractor" installed = True @@ -37,8 +44,8 @@ def __init__(self, timeseries: Union[PathType, np.ndarray], sampling_frequency: Sampling frequency of the video in Hz. """ super().__init__() - - if isinstance(timeseries, (str, Path)): + # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + if isinstance(timeseries, get_args(PathType)): timeseries = Path(timeseries) if timeseries.is_file(): assert timeseries.suffix == ".npy", "'timeseries' file is not a numpy file (.npy)" @@ -97,209 +104,282 @@ class NumpySegmentationExtractor(SegmentationExtractor): def __init__( self, - image_masks, - raw=None, - dff=None, - deconvolved=None, - neuropil=None, - accepted_lst=None, - mean_image=None, - correlation_image=None, - roi_ids=None, - roi_locations=None, - sampling_frequency=None, - rejected_list=None, - channel_names=None, - movie_dims=None, + image_masks: Union[PathType, np.ndarray], + roi_response_traces: dict[str, Union[PathType, np.ndarray]], + sampling_frequency: FloatType, + roi_ids: Optional[list] = None, + accepted_roi_ids: Optional[list] = None, + rejected_roi_ids: Optional[list] = None, + roi_locations: Optional[ArrayType] = None, + summary_images: Optional[dict[str, Union[PathType, np.ndarray]]] = None, + background_image_masks: Optional[Union[PathType, np.ndarray]] = None, + background_response_traces: Optional[dict[str, Union[PathType, np.ndarray]]] = None, + background_ids: Optional[list] = None, ): - """Create a NumpySegmentationExtractor from a .npy file. + """Create a NumpySegmentationExtractor from a set of .npy files or a set of np.ndarrays. Parameters ---------- - image_masks: np.ndarray - Binary image for each of the regions of interest - raw: np.ndarray - Fluorescence response of each of the ROI in time - dff: np.ndarray - DfOverF response of each of the ROI in time - deconvolved: np.ndarray - deconvolved response of each of the ROI in time - neuropil: np.ndarray - neuropil response of each of the ROI in time - mean_image: np.ndarray - Mean image - correlation_image: np.ndarray - correlation image - roi_ids: int list - Unique ids of the ROIs if any - roi_locations: np.ndarray - x and y location representative of ROI mask - sampling_frequency: float - Frame rate of the movie - rejected_list: list - list of ROI ids that are rejected manually or via automated rejection - channel_names: list - list of strings representing channel names - movie_dims: tuple - height x width of the movie - """ - SegmentationExtractor.__init__(self) - if isinstance(image_masks, (str, Path)): - image_masks = Path(image_masks) - if image_masks.is_file(): - assert image_masks.suffix == ".npy", "'image_masks' file is not a numpy file (.npy)" + image_masks: Union[PathType, np.ndarray] + Binary image for each of the regions of interest. + roi_response_traces: dict[str, Union[PathType, np.ndarray]] + Dictionary containing the fluorescence response of each ROI in time. + sampling_frequency: FloatType + Frame rate of the movie. + roi_ids: Optional[list] + Unique ids of the ROIs. If None, then the indices are used. + accepted_roi_ids: Optional[list] + List of ROI ids that are accepted. If None, then all ROI ids are accepted. + rejected_roi_ids: Optional[list] + List of ROI ids that are rejected manually or via automated rejection. If None, then no ROI ids are rejected. + roi_locations: Optional[ArrayType] + x and y location representative of ROI mask. If None, then the maximum location is used. + summary_images: Optional[dict[str, Union[PathType, np.ndarray]]] + Dictionary containing summary images like mean image, correlation image, etc. + background_image_masks: Optional[Union[PathType, np.ndarray]] + Binary image for each of the background components. + background_response_traces: Optional[dict[str, Union[PathType, np.ndarray]]] + Dictionary containing the background response of each component in time. + background_ids: Optional[list] + Unique ids of the background components. If None, then the indices are used. - self.is_dumpable = True - self._image_masks = np.load(image_masks, mmap_mode="r") - - if raw is not None: - raw = Path(raw) - assert raw.suffix == ".npy", "'raw' file is not a numpy file (.npy)" - self._roi_response_raw = np.load(raw, mmap_mode="r") - if dff is not None: - dff = Path(dff) - assert dff.suffix == ".npy", "'dff' file is not a numpy file (.npy)" - self._roi_response_dff = np.load(dff, mmap_mode="r") - self._roi_response_neuropil = np.load(neuropil, mmap_mode="r") - if deconvolved is not None: - deconvolved = Path(deconvolved) - assert deconvolved.suffix == ".npy", "'deconvolved' file is not a numpy file (.npy)" - self._roi_response_deconvolved = np.load(deconvolved, mmap_mode="r") - if neuropil is not None: - neuropil = Path(neuropil) - assert neuropil.suffix == ".npy", "'neuropil' file is not a numpy file (.npy)" - self._roi_response_neuropil = np.load(neuropil, mmap_mode="r") - - self._kwargs = {"image_masks": str(Path(image_masks).absolute())} - if raw is not None: - self._kwargs.update({"raw": str(Path(raw).absolute())}) - if raw is not None: - self._kwargs.update({"dff": str(Path(dff).absolute())}) - if raw is not None: - self._kwargs.update({"neuropil": str(Path(neuropil).absolute())}) - if raw is not None: - self._kwargs.update({"deconvolved": str(Path(deconvolved).absolute())}) + Notes + ----- + If any of image_masks, roi_response_traces, summary_images, background_image_masks, or background_response_traces + are .npy files, then the rest of them must be .npy files as well. + """ + super().__init__() + self._sampling_frequency = float(sampling_frequency) + # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + if isinstance(image_masks, get_args(PathType)): + self._init_from_npy( + image_masks=image_masks, + roi_response_traces=roi_response_traces, + summary_images=summary_images, + background_image_masks=background_image_masks, + background_response_traces=background_response_traces, + ) - else: - raise ValueError("'timeeseries' is does not exist") elif isinstance(image_masks, np.ndarray): - NoneType = type(None) - assert isinstance(raw, (np.ndarray, NoneType)) - assert isinstance(dff, (np.ndarray, NoneType)) - assert isinstance(neuropil, (np.ndarray, NoneType)) - assert isinstance(deconvolved, (np.ndarray, NoneType)) - self.is_dumpable = False - self._image_masks = image_masks - self._roi_response_raw = raw - assert self._image_masks.shape[-1] == self._roi_response_raw.shape[-1], ( - "Inconsistency between image masks and raw traces. " - "Image masks must be (px, py, num_rois), " - "traces must be (num_frames, num_rois)" + self._init_from_ndarray( + image_masks=image_masks, + roi_response_traces=roi_response_traces, + summary_images=summary_images, + background_image_masks=background_image_masks, + background_response_traces=background_response_traces, ) - self._roi_response_dff = dff - if self._roi_response_dff is not None: - assert self._image_masks.shape[-1] == self._roi_response_dff.shape[-1], ( - "Inconsistency between image masks and raw traces. " - "Image masks must be (px, py, num_rois), " - "traces must be (num_frames, num_rois)" - ) - self._roi_response_neuropil = neuropil - if self._roi_response_neuropil is not None: - assert self._image_masks.shape[-1] == self._roi_response_neuropil.shape[-1], ( - "Inconsistency between image masks and raw traces. " - "Image masks must be (px, py, num_rois), " - "traces must be (num_frames, num_rois)" - ) - self._roi_response_deconvolved = deconvolved - if self._roi_response_deconvolved is not None: - assert self._image_masks.shape[-1] == self._roi_response_deconvolved.shape[-1], ( - "Inconsistency between image masks and raw traces. " - "Image masks must be (px, py, num_rois), " - "traces must be (num_frames, num_rois)" - ) - self._kwargs = { - "image_masks": image_masks, - "signal": raw, - "dff": dff, - "neuropil": neuropil, - "deconvolved": deconvolved, - } - else: - raise TypeError("'image_masks' can be a str or a numpy array") - self._movie_dims = movie_dims if movie_dims is not None else image_masks.shape - self._image_mean = mean_image - self._image_correlation = correlation_image - if roi_ids is None: - self._roi_ids = list(np.arange(image_masks.shape[2])) else: - assert all([isinstance(roi_id, (int, np.integer)) for roi_id in roi_ids]), "'roi_ids' must be int!" - self._roi_ids = roi_ids - self._roi_locs = roi_locations - self._sampling_frequency = sampling_frequency - self._channel_names = channel_names - self._rejected_list = rejected_list - self._accepted_list = accepted_lst - - @property - def image_dims(self): - """Return the dimensions of the image. - - Returns - ------- - image_dims: list - The dimensions of the image (num_rois, num_rows, num_columns). - """ - return list(self._image_masks.shape[0:2]) + raise TypeError( + f"'image_masks' must be a PathType (str, pathlib.Path) or a numpy array but got {type(image_masks)}" + ) - def get_accepted_list(self): - if self._accepted_list is None: - return list(range(self.get_num_rois())) + self._image_size = self._image_masks.shape[:2] + self._num_rois = self._image_masks.shape[2] + self._num_frames = list(self._roi_response_traces.values())[0].shape[0] + self._roi_ids = roi_ids if roi_ids is not None else list(np.arange(self._num_rois)) + self._accepted_roi_ids = accepted_roi_ids if accepted_roi_ids is not None else self._roi_ids + self._rejected_roi_ids = ( + rejected_roi_ids if rejected_roi_ids is not None else list(set(self._roi_ids) - set(self._accepted_roi_ids)) + ) + + if roi_locations is not None: + self._roi_locations = roi_locations else: - return self._accepted_list + self._roi_locations = get_default_roi_locations_from_image_masks(self._image_masks) + if background_image_masks is not None: + self._num_background_components = self._background_image_masks.shape[2] + self._background_ids = ( + background_ids if background_ids is not None else list(np.arange(self._num_background_components)) + ) - def get_rejected_list(self): - if self._rejected_list is None: - return [a for a in range(self.get_num_rois()) if a not in set(self.get_accepted_list())] - else: - return self._rejected_list - - @property - def roi_locations(self): - """Returns the center locations (x, y) of each ROI.""" - if self._roi_locs is None: - num_ROIs = self.get_num_rois() - raw_images = self._image_masks - roi_location = np.ndarray([2, num_ROIs], dtype="int") - for i in range(num_ROIs): - temp = np.where(raw_images[:, :, i] == np.amax(raw_images[:, :, i])) - roi_location[:, i] = np.array([np.median(temp[0]), np.median(temp[1])]).T - return roi_location - else: - return self._roi_locs + def _init_from_npy( + self, + image_masks: PathType, + roi_response_traces: dict[str, PathType], + summary_images: Optional[dict[str, PathType]], + background_image_masks: Optional[PathType], + background_response_traces: Optional[dict[str, PathType]], + ): + image_masks = Path(image_masks) + assert image_masks.is_file(), "'image_masks' file does not exist" + assert image_masks.suffix == ".npy", "'image_masks' file is not a numpy file (.npy)" + + self.is_dumpable = True + self._image_masks = np.load(image_masks, mmap_mode="r") + + self._roi_response_traces = {} + for name, trace in roi_response_traces.items(): + assert isinstance( + trace, + get_args(PathType), # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + ), f"Since image_masks is a .npy file, roi response '{name}' must also be an .npy file but got {type(trace)}." + trace = Path(trace) + assert trace.is_file(), f"'{name}' file does not exist" + assert trace.suffix == ".npy", f"'{name}' file is not a numpy file (.npy)" + self._roi_response_traces[name] = np.load(trace, mmap_mode="r") + + if summary_images is not None: + self._summary_images = {} + for name, image in summary_images.items(): + assert isinstance( + image, + get_args( + PathType + ), # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + ), f"Since image_masks is a .npy file, summary image '{name}' must also be an .npy file but got {type(image)}." + image = Path(image) + assert image.is_file(), f"'{name}' file does not exist" + assert image.suffix == ".npy", f"'{name}' file is not a numpy file (.npy)" + self._summary_images[name] = np.load(image, mmap_mode="r") + + if background_image_masks is not None: + assert isinstance( + background_image_masks, + get_args(PathType), # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + ), f"Since image_masks is a .npy file, background image masks must also be a .npy file but got {type(background_image_masks)}." + background_image_masks = Path(background_image_masks) + assert background_image_masks.is_file(), "'background_image_masks' file does not exist" + assert background_image_masks.suffix == ".npy", "'background_image_masks' file is not a numpy file (.npy)" + self._background_image_masks = np.load(background_image_masks, mmap_mode="r") + + if background_response_traces is not None: + self._background_response_traces = {} + for name, trace in background_response_traces.items(): + assert isinstance( + trace, + get_args( + PathType + ), # python 3.9 doesn't support get_instance on a Union of types, so we use get_args + ), f"Since image_masks is a .npy file, background response '{name}' must also be a .npy file but got {type(trace)}." + trace = Path(trace) + assert trace.is_file(), f"'{name}' file does not exist" + assert trace.suffix == ".npy", f"'{name}' file is not a numpy file (.npy)" + self._background_response_traces[name] = np.load(trace, mmap_mode="r") + + def _init_from_ndarray( + self, image_masks, roi_response_traces, summary_images, background_image_masks, background_response_traces + ): + self.is_dumpable = False + self._image_masks = image_masks + + self._roi_response_traces = roi_response_traces + for name, trace in self._roi_response_traces.items(): + assert isinstance( + trace, np.ndarray + ), f"Since image_masks is a numpy array, roi response '{name}' must also be a numpy array but got {type(trace)}." + assert trace.shape[-1] == self._image_masks.shape[-1], ( + f"Inconsistency between image masks and {name} traces. " + f"Image masks must be (num_rows, num_columns, num_rois), " + f"traces must be (num_frames, num_rois)" + ) + if summary_images is not None: + self._summary_images = summary_images + for name, image in self._summary_images.items(): + assert image.shape[:2] == self._image_masks.shape[:2], ( + f"Inconsistency between image masks and {name} images. " + f"Image masks must be (num_rows, num_columns, num_rois), " + f"images must be (num_rows, num_columns)" + ) - @staticmethod - def write_segmentation(segmentation_object, save_path): - """Write a NumpySegmentationExtractor to a .npy file. + if background_image_masks is not None: + assert isinstance( + background_image_masks, np.ndarray + ), f"Since image_masks is a numpy array, background image masks must also be a numpy array but got {type(background_image_masks)}." + self._background_image_masks = background_image_masks + + if background_response_traces is not None: + assert ( + background_image_masks is not None + ), "Background image masks must be provided if background response traces are provided." + self._background_response_traces = background_response_traces + for name, trace in self._background_response_traces.items(): + assert trace.shape[-1] == self._background_image_masks.shape[-1], ( + "Inconsistency between background image masks and background response traces. " + "Background image masks must be (num_rows, num_columns, num_background_components), " + "background response traces must be (num_frames, num_background_components)" + ) - Parameters - ---------- - segmentation_object: NumpySegmentationExtractor - The segmentation extractor object to be written to file. - save_path: str or PathType - Path to .npy file. + def get_image_size(self): + return self._image_size - Notes - ----- - This method is not implemented yet. - """ - raise NotImplementedError + def get_num_frames(self): + return self._num_frames + + def get_sampling_frequency(self) -> float: + return self._sampling_frequency - # defining the abstract class informed methods: def get_roi_ids(self): - if self._roi_ids is None: - return list(range(self.get_num_rois())) - else: - return self._roi_ids + return self._roi_ids - def get_image_size(self): - return self._movie_dims + def get_num_rois(self): + return self._num_rois + + def get_accepted_roi_ids(self) -> list: + return self._accepted_roi_ids + + def get_rejected_roi_ids(self) -> list: + return self._rejected_roi_ids + + def get_roi_locations(self, roi_ids=None): + roi_indices = self.get_roi_indices(roi_ids=roi_ids) + return self._roi_locations[:, roi_indices] + + def get_roi_image_masks(self, roi_ids=None) -> np.ndarray: + if roi_ids is None: + return self._image_masks + roi_indices = self.get_roi_indices(roi_ids=roi_ids) + return self._image_masks[:, :, roi_indices] + + def get_roi_response_traces( + self, + names: Optional[list[str]] = None, + roi_ids: Optional[ArrayType] = None, + start_frame: Optional[IntType] = None, + end_frame: Optional[IntType] = None, + ) -> dict: + names = names if names is not None else list(self._roi_response_traces.keys()) + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + + roi_indices = self.get_roi_indices(roi_ids=roi_ids) + roi_response_traces = { + name: self._roi_response_traces[name][start_frame:end_frame, roi_indices] for name in names + } + return roi_response_traces + + def get_background_ids(self) -> list: + return self._background_ids + + def get_num_background_components(self) -> int: + return self._num_background_components + + def get_background_image_masks(self, background_ids=None) -> np.ndarray: + if background_ids is None: + return self._background_image_masks + all_ids = self.get_background_ids() + background_indices = [all_ids.index(i) for i in background_ids] + return self._background_image_masks[:, :, background_indices] + + def get_background_response_traces( + self, + names: Optional[list[str]] = None, + background_ids: Optional[ArrayType] = None, + start_frame: Optional[IntType] = None, + end_frame: Optional[IntType] = None, + ) -> dict: + names = names if names is not None else list(self._background_response_traces.keys()) + all_ids = self.get_background_ids() + background_ids = background_ids if background_ids is not None else all_ids + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + + background_indices = [all_ids.index(i) for i in background_ids] + background_response_traces = { + name: self._background_response_traces[name][start_frame:end_frame, background_indices] for name in names + } + return background_response_traces + + def get_summary_images(self, names: Optional[list[str]] = None) -> dict: + names = names if names is not None else list(self._summary_images.keys()) + summary_images = {name: self._summary_images[name] for name in names} + return summary_images diff --git a/src/roiextractors/segmentationextractor.py b/src/roiextractors/segmentationextractor.py index 3c2881c1..8e567c5e 100644 --- a/src/roiextractors/segmentationextractor.py +++ b/src/roiextractors/segmentationextractor.py @@ -11,7 +11,7 @@ """ from abc import ABC, abstractmethod -from typing import Union, Optional, Tuple, Iterable, List +from typing import Union, Optional, Tuple, Iterable, List, get_args import numpy as np from numpy.typing import ArrayLike @@ -33,118 +33,136 @@ class SegmentationExtractor(ABC): def __init__(self): """Create a new SegmentationExtractor for a specific data format (unique to each child SegmentationExtractor).""" - self._sampling_frequency = None self._times = None - self._channel_names = ["OpticalChannel"] - self._num_planes = 1 - self._roi_response_raw = None - self._roi_response_dff = None - self._roi_response_neuropil = None - self._roi_response_denoised = None - self._roi_response_deconvolved = None - self._image_correlation = None - self._image_mean = None - self._image_mask = None @abstractmethod - def get_accepted_list(self) -> list: - """Get a list of accepted ROI ids. + def get_image_size(self) -> ArrayType: + """Get frame size of movie (height, width). Returns ------- - accepted_list: list - List of accepted ROI ids. + no_rois: array_like + 2-D array: image height x image width """ pass @abstractmethod - def get_rejected_list(self) -> list: - """Get a list of rejected ROI ids. + def get_num_frames(self) -> int: + """Get the number of frames in the recording (duration of recording). Returns ------- - rejected_list: list - List of rejected ROI ids. + num_frames: int + Number of frames in the recording. """ pass @abstractmethod - def get_image_size(self) -> ArrayType: - """Get frame size of movie (height, width). + def get_sampling_frequency(self) -> float: + """Get the sampling frequency in Hz. Returns ------- - no_rois: array_like - 2-D array: image height x image width + sampling_frequency: float + Sampling frequency of the recording in Hz. """ pass - def get_num_frames(self) -> int: - """Get the number of frames in the recording (duration of recording). + @abstractmethod + def get_roi_ids(self) -> list: + """Get the list of ROI ids. Returns ------- - num_frames: int - Number of frames in the recording. + roi_ids: list + List of roi ids. """ - for trace in self.get_traces_dict().values(): - if trace is not None and len(trace.shape) > 0: - return trace.shape[0] + pass - def get_roi_locations(self, roi_ids=None) -> np.ndarray: - """Get the locations of the Regions of Interest (ROIs). + def get_roi_indices(self, roi_ids: Optional[list] = None) -> list: + """Get the list of ROI indices corresponding to the ROI ids. Parameters ---------- - roi_ids: array_like - A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. + roi_ids: list + List of roi ids. If None, all roi indices are returned. Returns ------- - roi_locs: numpy.ndarray - 2-D array: 2 X no_ROIs. The pixel ids (x,y) where the centroid of the ROI is. + roi_indices: list + List of roi indices. """ if roi_ids is None: - roi_idx_ = list(range(self.get_num_rois())) - else: - all_ids = self.get_roi_ids() - roi_idx_ = [all_ids.index(i) for i in roi_ids] - roi_location = np.zeros([2, len(roi_idx_)], dtype="int") - for c, i in enumerate(roi_idx_): - image_mask = self.get_roi_image_masks(roi_ids=[i]) - temp = np.where(image_mask == np.amax(image_mask)) - roi_location[:, c] = np.array([np.median(temp[0]), np.median(temp[1])]).T - return roi_location + return list(range(self.get_num_rois())) + all_roi_ids = self.get_roi_ids() + roi_indices = [all_roi_ids.index(roi_id) for roi_id in roi_ids] + return roi_indices - def get_roi_ids(self) -> list: - """Get the list of ROI ids. + @abstractmethod + def get_num_rois(self) -> int: + """Get total number of Regions of Interest (ROIs) in the acquired images. Returns ------- - roi_ids: list - List of roi ids. + num_rois: int + The number of ROIs extracted. """ - return list(range(self.get_num_rois())) + pass + + @abstractmethod + def get_accepted_roi_ids(self) -> list: + """Get a list of accepted ROI ids. + Returns + ------- + accepted_roi_ids: list + List of accepted ROI ids. + """ + pass + + @abstractmethod + def get_rejected_roi_ids(self) -> list: + """Get a list of rejected ROI ids. + + Returns + ------- + rejected_roi_ids: list + List of rejected ROI ids. + """ + pass + + @abstractmethod + def get_roi_locations(self, roi_ids=None) -> np.ndarray: + """Get the locations of the Regions of Interest (ROIs). + + Parameters + ---------- + roi_ids: array_like + A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. + + Returns + ------- + roi_locs: numpy.ndarray + 2-D array: 2 X no_ROIs. The pixel ids (x,y) where the centroid of the ROI is. + """ + pass + + @abstractmethod def get_roi_image_masks(self, roi_ids=None) -> np.ndarray: """Get the image masks extracted from segmentation algorithm. Parameters ---------- roi_ids: array_like - A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. + A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. If None, image masks for all + ROIs are returned. Returns ------- image_masks: numpy.ndarray 3-D array(val 0 or 1): image_height X image_width X length(roi_ids) """ - if roi_ids is None: - roi_idx_ = range(self.get_num_rois()) - else: - all_ids = self.get_roi_ids() - roi_idx_ = [all_ids.index(i) for i in roi_ids] - return np.stack([self._image_masks[:, :, k] for k in roi_idx_], 2) + pass def get_roi_pixel_masks(self, roi_ids=None) -> np.array: """Get the weights applied to each of the pixels of the mask. @@ -161,11 +179,37 @@ def get_roi_pixel_masks(self, roi_ids=None) -> np.array: Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of the pixel. """ - if roi_ids is None: - roi_ids = range(self.get_num_rois()) + return _pixel_mask_extractor(image_masks=self.get_roi_image_masks(roi_ids=roi_ids)) + + @abstractmethod + def get_roi_response_traces( + self, + names: Optional[list[str]] = None, + roi_ids: Optional[ArrayType] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + ) -> dict: + """Get the roi response traces. - return _pixel_mask_extractor(self.get_roi_image_masks(roi_ids=roi_ids), roi_ids) + Parameters + ---------- + names: list + List of names of the traces to retrieve. Must be one of {'raw', 'dff', 'deconvolved', 'denoised'}. If None, all traces are returned. + roi_ids: array_like + A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. If None, all ROIs are returned. + start_frame: int + The starting frame of the trace. If None, the trace starts from the beginning. + end_frame: int + The ending frame of the trace. If None, the trace ends at the last frame. + Returns + ------- + traces: dict + Dictionary of traces with key as the name of the trace and value as the trace. + """ + pass + + @abstractmethod def get_background_ids(self) -> list: """Get the list of background components ids. @@ -174,8 +218,20 @@ def get_background_ids(self) -> list: background_components_ids: list List of background components ids. """ - return list(range(self.get_num_background_components())) + pass + + @abstractmethod + def get_num_background_components(self) -> int: + """Get total number of background components in the acquired images. + + Returns + ------- + num_background_components: int + The number of background components extracted. + """ + pass + @abstractmethod def get_background_image_masks(self, background_ids=None) -> np.ndarray: """Get the background image masks extracted from segmentation algorithm. @@ -189,12 +245,7 @@ def get_background_image_masks(self, background_ids=None) -> np.ndarray: background_image_masks: numpy.ndarray 3-D array(val 0 or 1): image_height X image_width X length(background_ids) """ - if background_ids is None: - background_ids_ = range(self.get_num_background_components()) - else: - all_ids = self.get_background_ids() - background_ids_ = [all_ids.index(i) for i in background_ids] - return np.stack([self._background_image_masks[:, :, k] for k in background_ids_], 2) + pass def get_background_pixel_masks(self, background_ids=None) -> np.array: """Get the weights applied to each of the pixels of the mask. @@ -211,177 +262,53 @@ def get_background_pixel_masks(self, background_ids=None) -> np.array: Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of the pixel. """ - if background_ids is None: - background_ids = range(self.get_num_background_components()) - - return _pixel_mask_extractor(self.get_background_image_masks(background_ids=background_ids), background_ids) - - def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): - """Return a new SegmentationExtractor ranging from the start_frame to the end_frame. - - Parameters - ---------- - start_frame: int - The starting frame of the new SegmentationExtractor. - end_frame: int - The ending frame of the new SegmentationExtractor. - - Returns - ------- - frame_slice_segmentation_extractor: FrameSliceSegmentationExtractor - The frame slice segmentation extractor object. - """ - return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame) + return _pixel_mask_extractor(self.get_background_image_masks(background_ids=background_ids)) - def get_traces( + @abstractmethod + def get_background_response_traces( self, - roi_ids: ArrayType = None, + names: Optional[list[str]] = None, + background_ids: Optional[ArrayType] = None, start_frame: Optional[int] = None, end_frame: Optional[int] = None, - name: str = "raw", - ) -> ArrayType: - """Get the traces of each ROI specified by roi_ids. + ) -> dict: + """Get the background response traces. Parameters ---------- - roi_ids: array_like - A list or 1D array of ids of the ROIs. Length is the number of ROIs requested. + names: list + List of names of the traces to retrieve. Must be one of {'background'}. If None, all traces are returned. + background_ids: array_like + A list or 1D array of ids of the background components. Length is the number of background components requested. If None, all background components are returned. start_frame: int - The starting frame of the trace. + The starting frame of the trace. If None, the trace starts from the beginning. end_frame: int - The ending frame of the trace. - name: str - The name of the trace to retrieve ex. 'raw', 'dff', 'neuropil', 'deconvolved' - - Returns - ------- - traces: array_like - 2-D array (ROI x timepoints) - """ - if name not in self.get_traces_dict(): - raise ValueError(f"traces for {name} not found, enter one of {list(self.get_traces_dict().keys())}") - if roi_ids is not None: - all_ids = self.get_roi_ids() - roi_idxs = [all_ids.index(i) for i in roi_ids] - traces = self.get_traces_dict().get(name) - if traces is not None and len(traces.shape) != 0: - idxs = slice(None) if roi_ids is None else roi_idxs - return np.array(traces[start_frame:end_frame, :])[:, idxs] # numpy fancy indexing is quickest - - def get_traces_dict(self) -> dict: - """Get traces as a dictionary with key as the name of the ROiResponseSeries. + The ending frame of the trace. If None, the trace ends at the last frame. Returns ------- - _roi_response_dict: dict - dictionary with key, values representing different types of RoiResponseSeries: - Raw Fluorescence, DeltaFOverF, Denoised, Neuropil, Deconvolved, Background, etc. + traces: dict + Dictionary of traces with key as the name of the trace and value as the trace. """ - return dict( - raw=self._roi_response_raw, - dff=self._roi_response_dff, - neuropil=self._roi_response_neuropil, - deconvolved=self._roi_response_deconvolved, - denoised=self._roi_response_denoised, - ) - - def get_images_dict(self) -> dict: - """Get images as a dictionary with key as the name of the ROIResponseSeries. - - Returns - ------- - _roi_image_dict: dict - dictionary with key, values representing different types of Images used in segmentation: - Mean, Correlation image - """ - return dict(mean=self._image_mean, correlation=self._image_correlation) + pass - def get_image(self, name: str = "correlation") -> ArrayType: - """Get specific images: mean or correlation. + @abstractmethod + def get_summary_images(self, names: Optional[list[str]] = None) -> dict: + """Get summary images. Parameters ---------- - name:str - name of the type of image to retrieve + names: list + List of names of the images to retrieve. Must be one of {'mean', 'correlation'}. If None, all images are returned. Returns ------- - images: numpy.ndarray + summary_images: dict + Dictionary of summary images with key as the name of the image and value as the image. """ - if name not in self.get_images_dict(): - raise ValueError(f"could not find {name} image, enter one of {list(self.get_images_dict().keys())}") - return self.get_images_dict().get(name) - - def get_sampling_frequency(self) -> float: - """Get the sampling frequency in Hz. - - Returns - ------- - sampling_frequency: float - Sampling frequency of the recording in Hz. - """ - if self._sampling_frequency is not None: - return float(self._sampling_frequency) - - return self._sampling_frequency - - def get_num_rois(self) -> int: - """Get total number of Regions of Interest (ROIs) in the acquired images. - - Returns - ------- - num_rois: int - The number of ROIs extracted. - """ - for trace in self.get_traces_dict().values(): - if trace is not None and len(trace.shape) > 0: - return trace.shape[1] - - def get_num_background_components(self) -> int: - """Get total number of background components in the acquired images. - - Returns - ------- - num_background_components: int - The number of background components extracted. - """ - if self._roi_response_neuropil is not None and len(self._roi_response_neuropil.shape) > 0: - return self._roi_response_neuropil.shape[1] - - def get_channel_names(self) -> List[str]: - """Get names of channels in the pipeline. - - Returns - ------- - _channel_names: list - names of channels (str) - """ - return self._channel_names - - def get_num_channels(self) -> int: - """Get number of channels in the pipeline. - - Returns - ------- - num_of_channels: int - number of channels - """ - return len(self._channel_names) - - def get_num_planes(self) -> int: - """Get the default number of planes of imaging for the segmentation extractor. - - Notes - ----- - Defaults to 1 for all but the MultiSegmentationExtractor. - - Returns - ------- - self._num_planes: int - number of planes - """ - return self._num_planes + pass + # TODO: Refactor _times methods from ImagingExtractor and SegmentationExtractor into a BaseExtractor class def set_times(self, times: ArrayType): """Set the recording times in seconds for each frame. @@ -425,21 +352,33 @@ def frame_to_time(self, frames: Union[IntType, ArrayType]) -> Union[FloatType, A else: return self._times[frames] - @staticmethod - def write_segmentation(segmentation_extractor, save_path, overwrite=False): - """Write recording back to the native format. + def frame_slice(self, start_frame: Optional[int] = None, end_frame: Optional[int] = None): + """Return a new ImagingExtractor ranging from the start_frame to the end_frame. Parameters ---------- - segmentation_extractor: [SegmentationExtractor, MultiSegmentationExtractor] - The EXTRACT segmentation object from which an EXTRACT native format - file has to be generated. - save_path: str - path to save the native format. - overwrite: bool - If True, the file is overwritten if existing (default False) + start_frame: int, optional + Start frame index (inclusive). + end_frame: int, optional + End frame index (exclusive). + + Returns + ------- + imaging: FrameSliceImagingExtractor + The sliced ImagingExtractor object. """ - raise NotImplementedError + num_frames = self.get_num_frames() + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else num_frames + assert 0 <= start_frame < num_frames, f"'start_frame' must be in [0, {num_frames}) but got {start_frame}" + assert 0 < end_frame <= num_frames, f"'end_frame' must be in (0, {num_frames}] but got {end_frame}" + assert ( + start_frame <= end_frame + ), f"'start_frame' ({start_frame}) must be less than or equal to 'end_frame' ({end_frame})" + assert isinstance(start_frame, get_args(IntType)), "'start_frame' must be an integer" + assert isinstance(end_frame, get_args(IntType)), "'end_frame' must be an integer" + + return FrameSliceSegmentationExtractor(parent_segmentation=self, start_frame=start_frame, end_frame=end_frame) class FrameSliceSegmentationExtractor(SegmentationExtractor): @@ -454,8 +393,8 @@ class FrameSliceSegmentationExtractor(SegmentationExtractor): def __init__( self, parent_segmentation: SegmentationExtractor, - start_frame: Optional[int] = None, - end_frame: Optional[int] = None, + start_frame: int, + end_frame: int, ): """Create a new FrameSliceSegmentationExtractor from parent SegmentationExtractor. @@ -469,84 +408,85 @@ def __init__( The ending frame of the new SegmentationExtractor. """ self._parent_segmentation = parent_segmentation - self._start_frame = start_frame or 0 - self._end_frame = end_frame or self._parent_segmentation.get_num_frames() + self._start_frame = start_frame + self._end_frame = end_frame self._num_frames = self._end_frame - self._start_frame - if hasattr(self._parent_segmentation, "_image_masks"): # otherwise, do not set attribute at all - self._image_masks = self._parent_segmentation._image_masks - - parent_size = self._parent_segmentation.get_num_frames() - if start_frame is None: - start_frame = 0 - else: - assert 0 <= start_frame < parent_size - if end_frame is None: - end_frame = parent_size - else: - assert 0 < end_frame <= parent_size - assert end_frame > start_frame, "'start_frame' must be smaller than 'end_frame'!" - super().__init__() if getattr(self._parent_segmentation, "_times") is not None: self._times = self._parent_segmentation._times[start_frame:end_frame] - def get_accepted_list(self) -> list: - return self._parent_segmentation.get_accepted_list() - - def get_rejected_list(self) -> list: - return self._parent_segmentation.get_rejected_list() - - def get_traces( - self, - roi_ids: Optional[Iterable[int]] = None, - start_frame: Optional[int] = None, - end_frame: Optional[int] = None, - name: str = "raw", - ) -> np.ndarray: - start_frame = min(start_frame or 0, self._num_frames) - end_frame = min(end_frame or self._num_frames, self._num_frames) - return self._parent_segmentation.get_traces( - roi_ids=roi_ids, - start_frame=start_frame + self._start_frame, - end_frame=end_frame + self._start_frame, - name=name, - ) - - def get_traces_dict(self) -> dict: - return { - trace_name: self._parent_segmentation.get_traces( - start_frame=self._start_frame, end_frame=self._end_frame, name=trace_name - ) - for trace_name, trace in self._parent_segmentation.get_traces_dict().items() - } - def get_image_size(self) -> Tuple[int, int]: - return tuple(self._parent_segmentation.get_image_size()) + return self._parent_segmentation.get_image_size() def get_num_frames(self) -> int: return self._num_frames + def get_sampling_frequency(self) -> float: + return self._parent_segmentation.get_sampling_frequency() + + def get_roi_ids(self) -> list: + return self._parent_segmentation.get_roi_ids() + def get_num_rois(self) -> int: return self._parent_segmentation.get_num_rois() - def get_images_dict(self) -> dict: - return self._parent_segmentation.get_images_dict() + def get_accepted_roi_ids(self) -> list: + return self._parent_segmentation.get_accepted_roi_ids() - def get_image(self, name="correlation"): - return self._parent_segmentation.get_image(name=name) + def get_rejected_roi_ids(self) -> list: + return self._parent_segmentation.get_rejected_roi_ids() - def get_sampling_frequency(self) -> float: - return self._parent_segmentation.get_sampling_frequency() + def get_roi_locations(self, roi_ids=None) -> np.ndarray: + return self._parent_segmentation.get_roi_locations(roi_ids=roi_ids) + + def get_roi_image_masks(self, roi_ids=None) -> np.ndarray: + return self._parent_segmentation.get_roi_image_masks(roi_ids=roi_ids) - def get_channel_names(self) -> list: - return self._parent_segmentation.get_channel_names() + def get_roi_response_traces( + self, + names: Optional[list[str]] = None, + roi_ids: Optional[ArrayType] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + ) -> dict: + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + start_frame_shifted = start_frame + self._start_frame + end_frame_shifted = end_frame + self._start_frame + return self._parent_segmentation.get_roi_response_traces( + names=names, + roi_ids=roi_ids, + start_frame=start_frame_shifted, + end_frame=end_frame_shifted, + ) - def get_num_channels(self) -> int: - return self._parent_segmentation.get_num_channels() + def get_background_ids(self) -> list: + return self._parent_segmentation.get_background_ids() - def get_num_planes(self) -> int: - return self._parent_segmentation.get_num_planes() + def get_num_background_components(self) -> int: + return self._parent_segmentation.get_num_background_components() + + def get_background_image_masks(self, background_ids=None) -> np.ndarray: + return self._parent_segmentation.get_background_image_masks(background_ids=background_ids) + + def get_background_response_traces( + self, + names: Optional[list[str]] = None, + background_ids: Optional[ArrayType] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + ) -> dict: + start_frame = start_frame if start_frame is not None else 0 + end_frame = end_frame if end_frame is not None else self.get_num_frames() + start_frame_shifted = start_frame + self._start_frame + end_frame_shifted = end_frame + self._start_frame + return self._parent_segmentation.get_background_response_traces( + names=names, + background_ids=background_ids, + start_frame=start_frame_shifted, + end_frame=end_frame_shifted, + ) - def get_roi_pixel_masks(self, roi_ids: Optional[ArrayLike] = None) -> List[np.ndarray]: - return self._parent_segmentation.get_roi_pixel_masks(roi_ids=roi_ids) + def get_summary_images(self, names: Optional[list[str]] = None) -> dict: + return self._parent_segmentation.get_summary_images(names=names) diff --git a/src/roiextractors/testing.py b/src/roiextractors/testing.py index 353a2d11..3ea0d707 100644 --- a/src/roiextractors/testing.py +++ b/src/roiextractors/testing.py @@ -182,8 +182,8 @@ def generate_dummy_segmentation_extractor( correlation_image=correlation_image, roi_ids=roi_ids, roi_locations=roi_locations, - accepted_lst=accepeted_list, - rejected_list=rejected_list, + accepted_roi_ids=accepeted_list, + rejected_roi_ids=rejected_list, movie_dims=movie_dims, channel_names=["channel_num_0"], ) diff --git a/tests/mixins/imagingextractor_mixin.py b/tests/mixins/imaging_extractor_mixin.py similarity index 100% rename from tests/mixins/imagingextractor_mixin.py rename to tests/mixins/imaging_extractor_mixin.py diff --git a/tests/mixins/segmentation_extractor_mixin.py b/tests/mixins/segmentation_extractor_mixin.py new file mode 100644 index 00000000..cedec02f --- /dev/null +++ b/tests/mixins/segmentation_extractor_mixin.py @@ -0,0 +1,562 @@ +import pytest +import numpy as np + + +class SegmentationExtractorMixin: + def test_get_image_size(self, segmentation_extractor, expected_image_masks): + image_size = segmentation_extractor.get_image_size() + assert image_size == (expected_image_masks.shape[0], expected_image_masks.shape[1]) + + def test_get_num_frames(self, segmentation_extractor, expected_roi_response_traces): + num_frames = segmentation_extractor.get_num_frames() + first_expected_roi_response_trace = list(expected_roi_response_traces.values())[0] + assert num_frames == first_expected_roi_response_trace.shape[0] + + def test_get_sampling_frequency(self, segmentation_extractor, expected_sampling_frequency): + sampling_frequency = segmentation_extractor.get_sampling_frequency() + assert sampling_frequency == expected_sampling_frequency + + def test_get_roi_ids(self, segmentation_extractor, expected_roi_ids): + roi_ids = segmentation_extractor.get_roi_ids() + np.testing.assert_array_equal(roi_ids, expected_roi_ids) + + def test_get_roi_indices(self, segmentation_extractor, expected_roi_ids): + roi_indices = segmentation_extractor.get_roi_indices() + expected_roi_indices = list(range(len(expected_roi_ids))) + np.testing.assert_array_equal(roi_indices, expected_roi_indices) + + @pytest.mark.parametrize("expected_roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_indices_with_roi_ids(self, segmentation_extractor, expected_roi_ids, expected_roi_indices): + roi_ids = [expected_roi_ids[i] for i in expected_roi_indices] + roi_indices = segmentation_extractor.get_roi_indices(roi_ids=roi_ids) + np.testing.assert_array_equal(roi_indices, expected_roi_indices) + + def test_get_num_rois(self, segmentation_extractor, expected_roi_ids): + num_rois = segmentation_extractor.get_num_rois() + assert num_rois == len(expected_roi_ids) + + def test_get_accepted_roi_ids(self, segmentation_extractor, expected_accepted_list): + accepted_list = segmentation_extractor.get_accepted_roi_ids() + np.testing.assert_array_equal(accepted_list, expected_accepted_list) + + def test_get_rejected_roi_ids(self, segmentation_extractor, expected_rejected_list): + rejected_list = segmentation_extractor.get_rejected_roi_ids() + np.testing.assert_array_equal(rejected_list, expected_rejected_list) + + def test_get_roi_locations(self, segmentation_extractor, expected_roi_locations): + roi_locations = segmentation_extractor.get_roi_locations() + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_locations_with_roi_ids( + self, segmentation_extractor, expected_roi_locations, expected_roi_ids, roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + roi_locations = segmentation_extractor.get_roi_locations(roi_ids=roi_ids) + np.testing.assert_array_equal(roi_locations, expected_roi_locations[:, roi_indices]) + + def test_get_roi_image_masks(self, segmentation_extractor, expected_image_masks): + image_masks = segmentation_extractor.get_roi_image_masks() + np.testing.assert_array_equal(image_masks, expected_image_masks) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_image_masks_with_roi_ids( + self, segmentation_extractor, expected_image_masks, expected_roi_ids, roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + image_masks = segmentation_extractor.get_roi_image_masks(roi_ids=roi_ids) + np.testing.assert_array_equal(image_masks, expected_image_masks[:, :, roi_indices]) + + def test_get_roi_pixel_masks(self, segmentation_extractor, expected_image_masks): + pixel_masks = segmentation_extractor.get_roi_pixel_masks() + assert len(pixel_masks) == expected_image_masks.shape[2] + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_image_masks[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_pixel_masks_with_roi_ids( + self, segmentation_extractor, expected_image_masks, expected_roi_ids, roi_indices + ): + expected_image_masks_indexed = expected_image_masks[:, :, roi_indices] + roi_ids = [expected_roi_ids[i] for i in roi_indices] + pixel_masks = segmentation_extractor.get_roi_pixel_masks(roi_ids=roi_ids) + assert len(pixel_masks) == len(roi_indices) + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_image_masks_indexed[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + def test_get_roi_response_traces(self, segmentation_extractor, expected_roi_response_traces): + roi_response_traces = segmentation_extractor.get_roi_response_traces() + for name, expected_trace in expected_roi_response_traces.items(): + np.testing.assert_array_equal(roi_response_traces[name], expected_trace) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_response_traces_with_roi_ids( + self, segmentation_extractor, expected_roi_response_traces, expected_roi_ids, roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + roi_response_traces = segmentation_extractor.get_roi_response_traces(roi_ids=roi_ids) + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces[name][:, roi_indices] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("start_frame, end_frame", [(0, 1), (1, 3)]) + def test_get_roi_response_traces_with_frames( + self, segmentation_extractor, expected_roi_response_traces, start_frame, end_frame + ): + roi_response_traces = segmentation_extractor.get_roi_response_traces( + start_frame=start_frame, end_frame=end_frame + ) + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces[name][start_frame:end_frame, :] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("names", ([], ["raw"], ["dff"], ["raw", "dff"])) + def test_get_roi_response_traces_with_names(self, segmentation_extractor, expected_roi_response_traces, names): + roi_response_traces = segmentation_extractor.get_roi_response_traces(names=names) + assert list(roi_response_traces.keys()) == names + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces[name] + np.testing.assert_array_equal(trace, expected_trace) + + def test_get_background_ids(self, segmentation_extractor, expected_background_ids): + background_ids = segmentation_extractor.get_background_ids() + np.testing.assert_array_equal(background_ids, expected_background_ids) + + def test_get_num_background_components(self, segmentation_extractor, expected_background_ids): + num_background_components = segmentation_extractor.get_num_background_components() + assert num_background_components == len(expected_background_ids) + + def test_get_background_image_masks(self, segmentation_extractor, expected_background_image_masks): + background_image_masks = segmentation_extractor.get_background_image_masks() + np.testing.assert_array_equal(background_image_masks, expected_background_image_masks) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_image_masks_with_background_ids( + self, segmentation_extractor, expected_background_image_masks, expected_background_ids, background_indices + ): + expected_background_image_masks_indexed = expected_background_image_masks[:, :, background_indices] + background_ids = [expected_background_ids[i] for i in background_indices] + background_image_masks = segmentation_extractor.get_background_image_masks(background_ids=background_ids) + np.testing.assert_array_equal(background_image_masks, expected_background_image_masks_indexed) + + def test_get_background_pixel_masks(self, segmentation_extractor, expected_background_image_masks): + pixel_masks = segmentation_extractor.get_background_pixel_masks() + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_background_image_masks[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_pixel_masks_with_background_ids( + self, segmentation_extractor, expected_background_image_masks, expected_background_ids, background_indices + ): + expected_background_image_masks_indexed = expected_background_image_masks[:, :, background_indices] + background_ids = [expected_background_ids[i] for i in background_indices] + pixel_masks = segmentation_extractor.get_background_pixel_masks(background_ids=background_ids) + assert len(pixel_masks) == len(background_indices) + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_background_image_masks_indexed[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + def test_get_background_response_traces(self, segmentation_extractor, expected_background_response_traces): + background_response_traces = segmentation_extractor.get_background_response_traces() + for name, expected_trace in expected_background_response_traces.items(): + np.testing.assert_array_equal(background_response_traces[name], expected_trace) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_response_traces_with_background_components( + self, + segmentation_extractor, + expected_background_response_traces, + expected_background_ids, + background_indices, + ): + background_ids = [expected_background_ids[i] for i in background_indices] + background_response_traces = segmentation_extractor.get_background_response_traces( + background_ids=background_ids + ) + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces[name][:, background_indices] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("start_frame, end_frame", [(0, 1), (1, 3)]) + def test_get_background_response_traces_with_frames( + self, segmentation_extractor, expected_background_response_traces, start_frame, end_frame + ): + background_response_traces = segmentation_extractor.get_background_response_traces( + start_frame=start_frame, end_frame=end_frame + ) + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces[name][start_frame:end_frame, :] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("names", ([], ["background"])) + def test_get_background_response_traces_with_names( + self, segmentation_extractor, expected_background_response_traces, names + ): + background_response_traces = segmentation_extractor.get_background_response_traces(names=names) + assert list(background_response_traces.keys()) == names + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces[name] + np.testing.assert_array_equal(trace, expected_trace) + + def test_get_summary_images(self, segmentation_extractor, expected_summary_images): + summary_images = segmentation_extractor.get_summary_images() + for name, expected_image in expected_summary_images.items(): + np.testing.assert_array_equal(summary_images[name], expected_image) + + @pytest.mark.parametrize("names", ([], ["mean"], ["correlation"], ["mean", "correlation"])) + def test_get_summary_images_with_names(self, segmentation_extractor, expected_summary_images, names): + summary_images = segmentation_extractor.get_summary_images(names=names) + assert list(summary_images.keys()) == names + for name, image in summary_images.items(): + expected_image = expected_summary_images[name] + np.testing.assert_array_equal(image, expected_image) + + @pytest.mark.parametrize("start_frame, end_frame", [(None, None), (1, 3), (0, 1)]) + def test_frame_slice(self, segmentation_extractor, start_frame, end_frame): + frame_slice_imaging_extractor = segmentation_extractor.frame_slice(start_frame=start_frame, end_frame=end_frame) + start_frame = 0 if start_frame is None else start_frame + end_frame = segmentation_extractor.get_num_frames() if end_frame is None else end_frame + assert frame_slice_imaging_extractor._parent_segmentation is segmentation_extractor + assert frame_slice_imaging_extractor._start_frame == start_frame + assert frame_slice_imaging_extractor._end_frame == end_frame + assert frame_slice_imaging_extractor._num_frames == end_frame - start_frame + + def test_frame_slice_invalid_start_frame(self, segmentation_extractor): + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(start_frame=-1) + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(start_frame=segmentation_extractor.get_num_frames() + 1) + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(start_frame=0.5) + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(start_frame=3, end_frame=1) + + def test_frame_slice_invalid_end_frame(self, segmentation_extractor): + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(end_frame=-1) + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(end_frame=segmentation_extractor.get_num_frames() + 1) + with pytest.raises(AssertionError): + segmentation_extractor.frame_slice(end_frame=0.5) + + +class FrameSliceSegmentationExtractorMixin: + @pytest.fixture(scope="function") + def segmentation_extractor_frame_slice(self, segmentation_extractor, expected_roi_response_traces): + return segmentation_extractor.frame_slice(start_frame=1, end_frame=3) + + @pytest.fixture(scope="function") + def expected_roi_response_traces_frame_slice(self, expected_roi_response_traces): + return {name: trace[1:3, :] for name, trace in expected_roi_response_traces.items()} + + @pytest.fixture(scope="function") + def expected_background_response_traces_frame_slice(self, expected_background_response_traces): + return {name: trace[1:3, :] for name, trace in expected_background_response_traces.items()} + + def test_get_image_size_frame_slice(self, segmentation_extractor_frame_slice, expected_image_masks): + image_size = segmentation_extractor_frame_slice.get_image_size() + assert image_size == (expected_image_masks.shape[0], expected_image_masks.shape[1]) + + def test_get_num_frames_frame_slice( + self, segmentation_extractor_frame_slice, expected_roi_response_traces_frame_slice + ): + num_frames = segmentation_extractor_frame_slice.get_num_frames() + first_expected_roi_response_trace = list(expected_roi_response_traces_frame_slice.values())[0] + assert num_frames == first_expected_roi_response_trace.shape[0] + + def test_get_sampling_frequency_frame_slice(self, segmentation_extractor_frame_slice, expected_sampling_frequency): + sampling_frequency = segmentation_extractor_frame_slice.get_sampling_frequency() + assert sampling_frequency == expected_sampling_frequency + + def test_get_roi_ids_frame_slice(self, segmentation_extractor_frame_slice, expected_roi_ids): + roi_ids = segmentation_extractor_frame_slice.get_roi_ids() + np.testing.assert_array_equal(roi_ids, expected_roi_ids) + + def test_get_roi_indices_frame_slice(self, segmentation_extractor_frame_slice, expected_roi_ids): + roi_indices = segmentation_extractor_frame_slice.get_roi_indices() + expected_roi_indices = list(range(len(expected_roi_ids))) + np.testing.assert_array_equal(roi_indices, expected_roi_indices) + + @pytest.mark.parametrize("expected_roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_indices_with_roi_ids_frame_slice( + self, segmentation_extractor_frame_slice, expected_roi_ids, expected_roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in expected_roi_indices] + roi_indices = segmentation_extractor_frame_slice.get_roi_indices(roi_ids=roi_ids) + np.testing.assert_array_equal(roi_indices, expected_roi_indices) + + def test_get_num_rois_frame_slice(self, segmentation_extractor_frame_slice, expected_roi_ids): + num_rois = segmentation_extractor_frame_slice.get_num_rois() + assert num_rois == len(expected_roi_ids) + + def test_get_accepted_roi_ids_frame_slice(self, segmentation_extractor_frame_slice, expected_accepted_list): + accepted_list = segmentation_extractor_frame_slice.get_accepted_roi_ids() + np.testing.assert_array_equal(accepted_list, expected_accepted_list) + + def test_get_rejected_roi_ids_frame_slice(self, segmentation_extractor_frame_slice, expected_rejected_list): + rejected_list = segmentation_extractor_frame_slice.get_rejected_roi_ids() + np.testing.assert_array_equal(rejected_list, expected_rejected_list) + + def test_get_roi_locations_frame_slice(self, segmentation_extractor_frame_slice, expected_roi_locations): + roi_locations = segmentation_extractor_frame_slice.get_roi_locations() + np.testing.assert_array_equal(roi_locations, expected_roi_locations) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_locations_with_roi_ids_frame_slice( + self, segmentation_extractor_frame_slice, expected_roi_locations, expected_roi_ids, roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + roi_locations = segmentation_extractor_frame_slice.get_roi_locations(roi_ids=roi_ids) + np.testing.assert_array_equal(roi_locations, expected_roi_locations[:, roi_indices]) + + def test_get_roi_image_masks_frame_slice(self, segmentation_extractor_frame_slice, expected_image_masks): + image_masks = segmentation_extractor_frame_slice.get_roi_image_masks() + np.testing.assert_array_equal(image_masks, expected_image_masks) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_image_masks_with_roi_ids_frame_slice( + self, segmentation_extractor_frame_slice, expected_image_masks, expected_roi_ids, roi_indices + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + image_masks = segmentation_extractor_frame_slice.get_roi_image_masks(roi_ids=roi_ids) + np.testing.assert_array_equal(image_masks, expected_image_masks[:, :, roi_indices]) + + def test_get_roi_pixel_masks_frame_slice(self, segmentation_extractor_frame_slice, expected_image_masks): + pixel_masks = segmentation_extractor_frame_slice.get_roi_pixel_masks() + assert len(pixel_masks) == expected_image_masks.shape[2] + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_image_masks[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_pixel_masks_with_roi_ids_frame_slice( + self, segmentation_extractor_frame_slice, expected_image_masks, expected_roi_ids, roi_indices + ): + expected_image_masks_indexed = expected_image_masks[:, :, roi_indices] + roi_ids = [expected_roi_ids[i] for i in roi_indices] + pixel_masks = segmentation_extractor_frame_slice.get_roi_pixel_masks(roi_ids=roi_ids) + assert len(pixel_masks) == len(roi_indices) + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_image_masks_indexed[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + def test_get_roi_response_traces( + self, segmentation_extractor_frame_slice, expected_roi_response_traces_frame_slice + ): + roi_response_traces = segmentation_extractor_frame_slice.get_roi_response_traces() + for name, expected_trace in expected_roi_response_traces_frame_slice.items(): + np.testing.assert_array_equal(roi_response_traces[name], expected_trace) + + @pytest.mark.parametrize("roi_indices", ([], [0], [0, 1], [0, 2])) + def test_get_roi_response_traces_with_roi_ids( + self, + segmentation_extractor_frame_slice, + expected_roi_response_traces_frame_slice, + expected_roi_ids, + roi_indices, + ): + roi_ids = [expected_roi_ids[i] for i in roi_indices] + roi_response_traces = segmentation_extractor_frame_slice.get_roi_response_traces(roi_ids=roi_ids) + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces_frame_slice[name][:, roi_indices] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("start_frame, end_frame", [(0, 1), (1, 3)]) + def test_get_roi_response_traces_with_frames( + self, segmentation_extractor_frame_slice, expected_roi_response_traces_frame_slice, start_frame, end_frame + ): + roi_response_traces = segmentation_extractor_frame_slice.get_roi_response_traces( + start_frame=start_frame, end_frame=end_frame + ) + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces_frame_slice[name][start_frame:end_frame, :] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("names", ([], ["raw"], ["dff"], ["raw", "dff"])) + def test_get_roi_response_traces_with_names( + self, segmentation_extractor_frame_slice, expected_roi_response_traces_frame_slice, names + ): + roi_response_traces = segmentation_extractor_frame_slice.get_roi_response_traces(names=names) + assert list(roi_response_traces.keys()) == names + for name, trace in roi_response_traces.items(): + expected_trace = expected_roi_response_traces_frame_slice[name] + np.testing.assert_array_equal(trace, expected_trace) + + def test_get_background_ids(self, segmentation_extractor_frame_slice, expected_background_ids): + background_ids = segmentation_extractor_frame_slice.get_background_ids() + np.testing.assert_array_equal(background_ids, expected_background_ids) + + def test_get_num_background_components(self, segmentation_extractor_frame_slice, expected_background_ids): + num_background_components = segmentation_extractor_frame_slice.get_num_background_components() + assert num_background_components == len(expected_background_ids) + + def test_get_background_image_masks(self, segmentation_extractor_frame_slice, expected_background_image_masks): + background_image_masks = segmentation_extractor_frame_slice.get_background_image_masks() + np.testing.assert_array_equal(background_image_masks, expected_background_image_masks) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_image_masks_with_background_ids( + self, + segmentation_extractor_frame_slice, + expected_background_image_masks, + expected_background_ids, + background_indices, + ): + expected_background_image_masks_indexed = expected_background_image_masks[:, :, background_indices] + background_ids = [expected_background_ids[i] for i in background_indices] + background_image_masks = segmentation_extractor_frame_slice.get_background_image_masks( + background_ids=background_ids + ) + np.testing.assert_array_equal(background_image_masks, expected_background_image_masks_indexed) + + def test_get_background_pixel_masks(self, segmentation_extractor_frame_slice, expected_background_image_masks): + pixel_masks = segmentation_extractor_frame_slice.get_background_pixel_masks() + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_background_image_masks[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_pixel_masks_with_background_ids( + self, + segmentation_extractor_frame_slice, + expected_background_image_masks, + expected_background_ids, + background_indices, + ): + expected_background_image_masks_indexed = expected_background_image_masks[:, :, background_indices] + background_ids = [expected_background_ids[i] for i in background_indices] + pixel_masks = segmentation_extractor_frame_slice.get_background_pixel_masks(background_ids=background_ids) + assert len(pixel_masks) == len(background_indices) + for i, pixel_mask in enumerate(pixel_masks): + expected_image_mask = expected_background_image_masks_indexed[:, :, i] + expected_locs = np.where(expected_image_mask > 0) + expected_values = expected_image_mask[expected_image_mask > 0] + np.testing.assert_array_equal(pixel_mask[:, 0], expected_locs[0]) + np.testing.assert_array_equal(pixel_mask[:, 1], expected_locs[1]) + np.testing.assert_array_equal(pixel_mask[:, 2], expected_values) + + def test_get_background_response_traces( + self, segmentation_extractor_frame_slice, expected_background_response_traces_frame_slice + ): + background_response_traces = segmentation_extractor_frame_slice.get_background_response_traces() + for name, expected_trace in expected_background_response_traces_frame_slice.items(): + np.testing.assert_array_equal(background_response_traces[name], expected_trace) + + @pytest.mark.parametrize("background_indices", ([], [0], [0, 1], [0, 2])) + def test_get_background_response_traces_with_background_components( + self, + segmentation_extractor_frame_slice, + expected_background_response_traces_frame_slice, + expected_background_ids, + background_indices, + ): + background_ids = [expected_background_ids[i] for i in background_indices] + background_response_traces = segmentation_extractor_frame_slice.get_background_response_traces( + background_ids=background_ids + ) + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces_frame_slice[name][:, background_indices] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("start_frame, end_frame", [(0, 1), (1, 3)]) + def test_get_background_response_traces_with_frames( + self, + segmentation_extractor_frame_slice, + expected_background_response_traces_frame_slice, + start_frame, + end_frame, + ): + background_response_traces = segmentation_extractor_frame_slice.get_background_response_traces( + start_frame=start_frame, end_frame=end_frame + ) + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces_frame_slice[name][start_frame:end_frame, :] + np.testing.assert_array_equal(trace, expected_trace) + + @pytest.mark.parametrize("names", ([], ["background"])) + def test_get_background_response_traces_with_names( + self, segmentation_extractor_frame_slice, expected_background_response_traces_frame_slice, names + ): + background_response_traces = segmentation_extractor_frame_slice.get_background_response_traces(names=names) + assert list(background_response_traces.keys()) == names + for name, trace in background_response_traces.items(): + expected_trace = expected_background_response_traces_frame_slice[name] + np.testing.assert_array_equal(trace, expected_trace) + + def test_get_summary_images(self, segmentation_extractor_frame_slice, expected_summary_images): + summary_images = segmentation_extractor_frame_slice.get_summary_images() + for name, expected_image in expected_summary_images.items(): + np.testing.assert_array_equal(summary_images[name], expected_image) + + @pytest.mark.parametrize("names", ([], ["mean"], ["correlation"], ["mean", "correlation"])) + def test_get_summary_images_with_names(self, segmentation_extractor_frame_slice, expected_summary_images, names): + summary_images = segmentation_extractor_frame_slice.get_summary_images(names=names) + assert list(summary_images.keys()) == names + for name, image in summary_images.items(): + expected_image = expected_summary_images[name] + np.testing.assert_array_equal(image, expected_image) + + @pytest.mark.parametrize("start_frame, end_frame", [(None, None), (1, 2), (0, 1)]) + def test_frame_slice_on_frame_slice(self, segmentation_extractor_frame_slice, start_frame, end_frame): + twice_sliced_segmentation_extractor = segmentation_extractor_frame_slice.frame_slice( + start_frame=start_frame, end_frame=end_frame + ) + start_frame = 0 if start_frame is None else start_frame + end_frame = twice_sliced_segmentation_extractor.get_num_frames() if end_frame is None else end_frame + assert twice_sliced_segmentation_extractor._parent_segmentation is segmentation_extractor_frame_slice + assert twice_sliced_segmentation_extractor._start_frame == start_frame + assert twice_sliced_segmentation_extractor._end_frame == end_frame + assert twice_sliced_segmentation_extractor._num_frames == end_frame - start_frame + + def test_frame_slice_invalid_start_frame_on_frame_slice(self, segmentation_extractor_frame_slice): + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice(start_frame=-1) + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice( + start_frame=segmentation_extractor_frame_slice.get_num_frames() + 1 + ) + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice(start_frame=0.5) + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice(start_frame=3, end_frame=1) + + def test_frame_slice_invalid_end_frame_on_frame_slice(self, segmentation_extractor_frame_slice): + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice(end_frame=-1) + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice( + end_frame=segmentation_extractor_frame_slice.get_num_frames() + 1 + ) + with pytest.raises(AssertionError): + segmentation_extractor_frame_slice.frame_slice(end_frame=0.5) diff --git a/tests/test_minimal/test_numpy_imaging_extractor.py b/tests/test_minimal/test_numpy_imaging_extractor.py index 32f27bdc..332032f7 100644 --- a/tests/test_minimal/test_numpy_imaging_extractor.py +++ b/tests/test_minimal/test_numpy_imaging_extractor.py @@ -1,4 +1,4 @@ -from ..mixins.imagingextractor_mixin import ImagingExtractorMixin, FrameSliceImagingExtractorMixin +from ..mixins.imaging_extractor_mixin import ImagingExtractorMixin, FrameSliceImagingExtractorMixin from roiextractors import NumpyImagingExtractor from roiextractors.testing import generate_dummy_video import pytest diff --git a/tests/test_minimal/test_numpy_segmentation_extractor.py b/tests/test_minimal/test_numpy_segmentation_extractor.py new file mode 100644 index 00000000..2d9ee984 --- /dev/null +++ b/tests/test_minimal/test_numpy_segmentation_extractor.py @@ -0,0 +1,182 @@ +from ..mixins.segmentation_extractor_mixin import SegmentationExtractorMixin, FrameSliceSegmentationExtractorMixin +from roiextractors import NumpySegmentationExtractor +import pytest +import numpy as np + + +@pytest.fixture(scope="module") +def rng(): + seed = 1727293748 # int(datetime.now().timestamp()) at the time of writing + return np.random.default_rng(seed=seed) + + +@pytest.fixture(scope="module") +def num_rows(): + return 25 + + +@pytest.fixture(scope="module") +def num_columns(): + return 25 + + +@pytest.fixture(scope="module") +def num_rois(): + return 10 + + +@pytest.fixture(scope="module") +def num_frames(): + return 100 + + +@pytest.fixture(scope="module") +def num_background_components(): + return 3 + + +@pytest.fixture(scope="module") +def expected_image_masks(rng, num_rows, num_columns, num_rois): + return rng.random((num_rows, num_columns, num_rois)) + + +@pytest.fixture(scope="module") +def expected_roi_response_traces(rng, num_frames, num_rois): + trace_names = ["raw", "dff", "deconvolved", "denoised"] + traces_dict = {name: rng.random((num_frames, num_rois)) for name in trace_names} + return traces_dict + + +@pytest.fixture(scope="module") +def expected_background_response_traces(rng, num_frames, num_background_components): + trace_names = ["background"] + traces_dict = {name: rng.random((num_frames, num_background_components)) for name in trace_names} + return traces_dict + + +@pytest.fixture(scope="module") +def expected_summary_images(rng, num_rows, num_columns): + image_names = ["mean", "correlation"] + summary_images = {name: rng.random((num_rows, num_columns)) for name in image_names} + return summary_images + + +@pytest.fixture(scope="module") +def expected_roi_ids(num_rois): + return list(range(num_rois)) + + +@pytest.fixture(scope="module") +def expected_roi_locations(rng, num_rois, num_rows, num_columns): + roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois) + roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois) + roi_locations = np.vstack((roi_locations_rows, roi_locations_columns)) + return roi_locations + + +@pytest.fixture(scope="module") +def expected_accepted_list(rng, expected_roi_ids, num_rois): + return rng.choice(expected_roi_ids, size=num_rois // 2, replace=False) + + +@pytest.fixture(scope="module") +def expected_rejected_list(expected_roi_ids, expected_accepted_list): + return list(set(expected_roi_ids) - set(expected_accepted_list)) + + +@pytest.fixture(scope="module") +def expected_sampling_frequency(): + return 30.0 + + +@pytest.fixture(scope="module") +def expected_background_ids(num_background_components): + return list(range(num_background_components)) + + +@pytest.fixture(scope="module") +def expected_background_image_masks(rng, num_rows, num_columns, num_background_components): + return rng.random((num_rows, num_columns, num_background_components)) + + +class TestNumpySegmentationExtractor(SegmentationExtractorMixin, FrameSliceSegmentationExtractorMixin): + @pytest.fixture(scope="function") + def segmentation_extractor( + self, + expected_image_masks, + expected_roi_response_traces, + expected_summary_images, + expected_roi_ids, + expected_roi_locations, + expected_accepted_list, + expected_rejected_list, + expected_background_response_traces, + expected_background_ids, + expected_background_image_masks, + expected_sampling_frequency, + ): + return NumpySegmentationExtractor( + image_masks=expected_image_masks, + roi_response_traces=expected_roi_response_traces, + summary_images=expected_summary_images, + roi_ids=expected_roi_ids, + roi_locations=expected_roi_locations, + accepted_roi_ids=expected_accepted_list, + rejected_roi_ids=expected_rejected_list, + sampling_frequency=expected_sampling_frequency, + background_ids=expected_background_ids, + background_image_masks=expected_background_image_masks, + background_response_traces=expected_background_response_traces, + ) + + +class TestNumpySegmentationExtractorFromFile(SegmentationExtractorMixin, FrameSliceSegmentationExtractorMixin): + @pytest.fixture(scope="function") + def segmentation_extractor( + self, + expected_image_masks, + expected_roi_response_traces, + expected_summary_images, + expected_roi_ids, + expected_roi_locations, + expected_accepted_list, + expected_rejected_list, + expected_background_response_traces, + expected_background_ids, + expected_background_image_masks, + expected_sampling_frequency, + tmp_path, + ): + name_to_ndarray = dict( + image_masks=expected_image_masks, + background_image_masks=expected_background_image_masks, + ) + name_to_file_path = {} + for name, ndarray in name_to_ndarray.items(): + file_path = tmp_path / f"{name}.npy" + file_path.parent.mkdir(parents=True, exist_ok=True) + np.save(file_path, ndarray) + name_to_file_path[name] = file_path + name_to_dict_of_ndarrays = dict( + roi_response_traces=expected_roi_response_traces, + background_response_traces=expected_background_response_traces, + summary_images=expected_summary_images, + ) + name_to_dict_of_file_paths = {} + for name, dict_of_ndarrays in name_to_dict_of_ndarrays.items(): + name_to_dict_of_file_paths[name] = {} + for key, ndarray in dict_of_ndarrays.items(): + file_path = tmp_path / f"{name}_{key}.npy" + np.save(file_path, ndarray) + name_to_dict_of_file_paths[name][key] = file_path + + return NumpySegmentationExtractor( + **name_to_file_path, + **name_to_dict_of_file_paths, + roi_ids=expected_roi_ids, + roi_locations=expected_roi_locations, + accepted_roi_ids=expected_accepted_list, + rejected_roi_ids=expected_rejected_list, + sampling_frequency=expected_sampling_frequency, + background_ids=expected_background_ids, + )