diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index a69f704..47017a1 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -9,7 +9,7 @@ import xarray from alphabase.io import tempmmap from spatialdata import SpatialData -from spatialdata.models import PointsModel, TableModel +from spatialdata.models import Image2DModel, PointsModel, TableModel from spatialdata.transformations.transformations import Identity from scportrait.pipeline._base import Logable @@ -19,7 +19,8 @@ get_chunk_size, ) -ChunkSize: TypeAlias = tuple[int, int] +ChunkSize2D: TypeAlias = tuple[int, int] +ChunkSize3D: TypeAlias = tuple[int, int, int] ObjectType: TypeAlias = Literal["images", "labels", "points", "tables"] @@ -143,6 +144,58 @@ def _get_input_image(self, sdata: SpatialData) -> xarray.DataArray: return input_image ## write elements to sdata object + def _write_image_sdata( + self, + image, + image_name: str, + channel_names: list[str] = None, + scale_factors: list[int] = None, + chunks: ChunkSize3D = (1, 1000, 1000), + overwrite=False, + ): + """ + Write the supplied image to the spatialdata object. + + Args: + image (dask.array): Image to be written to the spatialdata object. + image_name (str): Name of the image to be written to the spatialdata object. + channel_names list[str]: List of channel names for the image. Default is None. + scale_factors list[int]: List of scale factors for the image. Default is [2, 4, 8]. This will load the image at 4 different resolutions to allow for fluid visualization. + chunks (tuple): Chunk size for the image. Default is (1, 1000, 1000). + overwrite (bool): Whether to overwrite existing data. Default is False. 
+ """ + + if scale_factors is None: + scale_factors = [2, 4, 8] + if scale_factors is None: + scale_factors = [2, 4, 8] + + _sdata = self._read_sdata() + + if channel_names is None: + channel_names = [f"channel_{i}" for i in range(image.shape[0])] + + # transform to spatialdata image model + transform_original = Identity() + image = Image2DModel.parse( + image, + dims=["c", "y", "x"], + chunks=chunks, + c_coords=channel_names, + scale_factors=scale_factors, + transformations={"global": transform_original}, + rgb=False, + ) + + if overwrite: + self._force_delete_object(_sdata, image_name, "images") + + _sdata.images[image_name] = image + _sdata.write_element(image_name, overwrite=True) + + self.log(f"Image {image_name} written to sdata object.") + self._check_sdata_status() + def _write_segmentation_object_sdata( self, segmentation_object: spLabels2DModel, @@ -177,7 +230,7 @@ def _write_segmentation_sdata( segmentation: xarray.DataArray | np.ndarray, segmentation_label: str, classes: set[str] | None = None, - chunks: ChunkSize = (1000, 1000), + chunks: ChunkSize2D = (1000, 1000), overwrite: bool = False, ) -> None: """Write segmentation data to SpatialData. 
@@ -268,6 +321,7 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None centroids_object = self._get_centers(_sdata, segmentation_label) self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite) + ## load elements from sdata to a memory mapped array def _load_input_image_to_memmap( self, tmp_dir_abs_path: str | Path, image: np.typing.NDArray[Any] | None = None ) -> str: diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index 0ccb642..ff51dfd 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -842,8 +842,10 @@ def process(self, partial=False, n_cells=None, seed=42): self.log("Loading input images to memory mapped arrays...") start_data_transfer = timeit.default_timer() - self.path_seg_masks = self.project._load_seg_to_memmap(seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path) - self.path_image_data = self.project._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path) + self.path_seg_masks = self.filehandler._load_seg_to_memmap( + seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path + ) + self.path_image_data = self.filehandler._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path) stop_data_transfer = timeit.default_timer() time_data_transfer = stop_data_transfer - start_data_transfer diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index b01beba..bb968c8 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -2,21 +2,22 @@ project ======= -Within scPortrait, all operations are centered around the concept of a ``Project``. A ``Project`` is a python class which manages -all of the scPortrait processing steps and is the central element through which all operations are performed. 
Each ``Project`` directly -maps to a directory on the file system which contains all of the inputs to a specific scPortrait run as well as the generated outputs. -Depending on the structure of the data that is to be processed a different Project class is required. +At the core of scPortrait is the concept of a `Project`. A `Project` is a Python class that orchestrates all scPortrait processing steps, serving as the central element for all operations. +Each `Project` corresponds to a directory on the file system, which houses the input data for a specific scPortrait run along with the generated outputs. +The choice of the appropriate `Project` class depends on the structure of the data to be processed. -Please see :ref:`here ` for more information. +For more details, refer to :ref:`here `. """ +from __future__ import annotations + import os import re import shutil import tempfile import warnings from time import time -from typing import Literal +from typing import TYPE_CHECKING, Literal import dask.array as darray import numpy as np @@ -28,14 +29,11 @@ from ome_zarr.io import parse_url from ome_zarr.reader import Reader from spatialdata import SpatialData -from spatialdata.models import Image2DModel, PointsModel -from spatialdata.transformations.transformations import Identity from tifffile import imread from scportrait.io import daskmmap from scportrait.pipeline._base import Logable from scportrait.pipeline._utils.sdata_io import sdata_filehandler -from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel from scportrait.pipeline._utils.spatialdata_helper import ( calculate_centroids, generate_region_annotation_lookuptable, @@ -45,10 +43,27 @@ remap_region_annotation_table, ) +if TYPE_CHECKING: + from collections.abc import Callable + class Project(Logable): - CLEAN_LOG = True + """Base implementation for a scPortrait ``project``. + + This class is designed to handle single-timepoint, single-location data, like e.g. whole-slide images. 
+
+ Segmentation Methods should be based on :func:`Segmentation ` or :func:`ShardedSegmentation `.
+ Extraction Methods should be based on :func:`HDF5CellExtraction `.
+ Attributes:
+ config (dict): Dictionary containing the config file.
+ nuc_seg_name (str): Name of the nucleus segmentation object.
+ cyto_seg_name (str): Name of the cytosol segmentation object.
+ sdata_path (str): Path to the spatialdata object.
+ filehandler (sdata_filehandler): Filehandler for the spatialdata object which manages all calls or updates to the spatialdata object.
+ """
+
+ CLEAN_LOG: bool = True
 DEFAULT_CONFIG_NAME = "config.yml"
 DEFAULT_INPUT_IMAGE_NAME = "input_image"
 DEFAULT_SDATA_FILE = "scportrait.sdata"
@@ -57,10 +72,10 @@ class Project(Logable):
 DEFAULT_PREFIX_FILTERED_SEG = "seg_filtered"
 DEFAULT_PREFIX_SELECTED_SEG = "seg_selected"
- DEFAULT_SEG_NAME_0 = "nucleus"
- DEFAULT_SEG_NAME_1 = "cytosol"
+ DEFAULT_SEG_NAME_0: str = "nucleus"
+ DEFAULT_SEG_NAME_1: str = "cytosol"
- DEFAULT_CENTERS_NAME = "centers_cells"
+ DEFAULT_CENTERS_NAME: str = "centers_cells"
 DEFAULT_CHUNK_SIZE = (1, 1000, 1000)
@@ -78,15 +93,26 @@ class Project(Logable):
 def __init__(
 self,
- project_location,
- config_path,
+ project_location: str,
+ config_path: str,
 segmentation_f=None,
 extraction_f=None,
 featurization_f=None,
 selection_f=None,
- overwrite=False,
- debug=False,
+ overwrite: bool = False,
+ debug: bool = False,
 ):
+ """
+ Args:
+ project_location (str): Path to the project directory.
+ config_path (str): Path to the config file.
+ segmentation_f (optional): Segmentation method to be used for the project.
+ extraction_f (optional): Extraction method to be used for the project.
+ featurization_f (optional): Featurization method to be used for the project.
+ selection_f (optional): Selection method to be used for the project.
+ overwrite (optional): If set to ``True``, will overwrite existing files in the project directory. Default is ``False``. 
+ debug (optional): If set to ``True``, will print additional debug messages/plots. Default is ``False``.
+ """
 super().__init__(directory=project_location, debug=debug)
 self.project_location = project_location
@@ -139,6 +165,12 @@ def __init__(
 # ==== setup selection ===
 self._setup_selection(selection_f)
+ def __exit__(self, exc_type=None, exc_value=None, traceback=None):
+ self._clear_temp_dir()
+
+ def __del__(self):
+ self._clear_temp_dir()
+
 ##### Setup Functions #####
 def _load_config_from_file(self, file_path):
@@ -159,7 +191,15 @@ def _load_config_from_file(self, file_path):
 except yaml.YAMLError as exc:
 print(exc)
- def _get_config_file(self, config_path: str | None = None):
+ def _get_config_file(self, config_path: str | None = None) -> None:
+ """Load the config file for the project. If no config file is passed the default config file in the project directory is loaded.
+
+ Args:
+ config_path (str, optional): Path to the config file. Default is ``None``.
+
+ Returns:
+ None: the config dictionary project.config is updated.
+ """
 # load config file
 self.config_path = os.path.join(self.project_location, self.DEFAULT_CONFIG_NAME)
@@ -190,7 +230,15 @@ def _get_config_file(self, config_path: str | None = None):
 self._load_config_from_file(self.config_path)
 def _setup_segmentation_f(self, segmentation_f):
- if self.segmentation_f is not None:
+ """Configure the segmentation method for the project.
+
+ Args:
+ segmentation_f (Callable): Segmentation method to be used for the project.
+
+ Returns:
+ None: the segmentation method is updated in the project object.
+ """
+ if segmentation_f is not None:
 if segmentation_f.__name__ not in self.config:
 raise ValueError(f"Config for {segmentation_f.__name__} is missing from the config file.")
@@ -212,6 +260,15 @@ def _setup_segmentation_f(self, segmentation_f):
 )
 def _setup_extraction_f(self, extraction_f):
+ """Configure the extraction method for the project.
+
+ Args:
+ extraction_f (Callable): Extraction method to be used for the project. 
+ + Returns: + None: the extraction method is updated in the project object. + """ + if extraction_f is not None: extraction_directory = os.path.join(self.project_location, self.DEFAULT_EXTRACTION_DIR_NAME) @@ -231,6 +288,14 @@ def _setup_extraction_f(self, extraction_f): ) def _setup_featurization_f(self, featurization_f): + """Configure the featurization method for the project. + + Args: + featurization_f (Callable): Featurization method to be used for the project. + + Returns: + None: the featurization method is updated in the project object. + """ if featurization_f is not None: if featurization_f.__name__ not in self.config: raise ValueError(f"Config for {featurization_f.__name__} is missing from the config file") @@ -250,6 +315,14 @@ def _setup_featurization_f(self, featurization_f): ) def _setup_selection(self, selection_f): + """Configure the selection method for the project. + + Args: + selection_f (Callable): Selection method to be used for the project. + + Returns: + None: the selection method is updated in the project object. + """ if self.selection_f is not None: if selection_f.__name__ not in self.config: raise ValueError(f"Config for {selection_f.__name__} is missing from the config file") @@ -268,14 +341,21 @@ def _setup_selection(self, selection_f): filehandler=self.filehandler, ) - def update_featurization_f(self, featurization_f) -> None: + def update_featurization_f(self, featurization_f): """Update the featurization method chosen for the project without reinitializing the entire project. - Parameters - ---------- - featurization_f : class - The featurization method that should be used for the project. + Args: + featurization_f : The featurization method that should be used for the project. + + Returns: + None : the featurization method is updated in the project object. 
+ Examples: + Update the featurization method for a project:: + + from scportrait.pipeline.featurization import CellFeaturizer + + project.update_featurization_f(CellFeaturizer) """ self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}") self._setup_featurization_f(featurization_f) @@ -314,8 +394,18 @@ def _check_chunk_size(self, elem): return elem - def _check_image_dtype(self, image): - """Check if the image dtype is the default image dtype. If not raise a warning.""" + def _check_image_dtype(self, image: np.ndarray) -> None: + """Check if the image dtype is the default image dtype. + + Args: + image (np.ndarray): Image to be checked. + + Returns: + None: If the image dtype is the default image dtype, no action is taken. + + Raises: + Warning: If the image dtype is not the default image dtype. + """ if not image.dtype == self.DEFAULT_IMAGE_DTYPE: Warning( @@ -325,18 +415,27 @@ def _check_image_dtype(self, image): f"Image dtype is not {self.DEFAULT_IMAGE_DTYPE} but insteadt {image.dtype}. The workflow expects images to be of dtype {self.DEFAULT_IMAGE_DTYPE}. Proceeding with the incorrect dtype can lead to unexpected results." ) - def _create_temp_dir(self, path): + def _create_temp_dir(self, path) -> None: """ - Create a temporary directory in the specified directory with the name of the class.s + Create a temporary directory in the specified directory with the name of the class. + + Args: + path (str): Path to the directory where the temporary directory should be created. + + Returns: + None: The temporary directory is created in the specified directory. The path to the temporary directory is stored in the project object as self._tmp_dir_path. 
+ """ path = os.path.join(path, f"{self.__class__.__name__}_") self._tmp_dir = tempfile.TemporaryDirectory(prefix=path) self._tmp_dir_path = self._tmp_dir.name + """str: Path to the temporary directory.""" self.log(f"Initialized temporary directory at {self._tmp_dir_path} for {self.__class__.__name__}") def _clear_temp_dir(self): + """Clear the temporary directory.""" if "_tmp_dir" in self.__dict__.keys(): shutil.rmtree(self._tmp_dir_path, ignore_errors=True) self.log(f"Cleaned up temporary directory at {self._tmp_dir}") @@ -382,6 +481,10 @@ def _ensure_all_labels_habe_cell_ids(self): if not hasattr(self.sdata.labels[keys].attrs, "cell_ids"): self.sdata.labels[keys].attrs["cell_ids"] = get_unique_cell_ids(self.sdata.labels[keys]) + def print_project_status(self): + """Print the current project status.""" + self._check_sdata_status(print_status=True) + def _check_sdata_status(self, print_status=False): if self.sdata is None: self._read_sdata() @@ -412,233 +515,20 @@ def _check_sdata_status(self, print_status=False): return None def _read_sdata(self): - self.sdata = self.filehandler.get_sdata() + self.sdata = self.filehandler.get_sdata() # type: SpatialData self._check_sdata_status() def view_sdata(self): + """Start an interactive napari viewer to look at the sdata object associated with the project. + Note: + This only works in sessions with a visual interface. + """ self.sdata = self.filehandler.get_sdata() # ensure its up to date # open interactive viewer in napari interactive = Interactive(self.sdata) interactive.run() - #### Functions for adding elements to sdata object ######## - def _force_delete_object(self, name: str, type: str): - """ - Force delete an object from the sdata object and the corresponding directory. - - Parameters - ---------- - name : str - Name of the object to be deleted. - type : str - Type of the object to be deleted. Can be either "images", "labels", "points" or "tables". 
- """ - if name in self.sdata: - del self.sdata[name] - - # define path - path = os.path.join(self.sdata_path, type, name) - if os.path.exists(path): - shutil.rmtree(path, ignore_errors=True) - - def _write_image_sdata( - self, - image, - image_name, - channel_names=None, - scale_factors=None, - chunks=(1, 1000, 1000), - overwrite=False, - ): - """ - Write the supplied image to the spatialdata object. - - Parameters - ---------- - image : dask.array - Image to be written to the spatialdata object. - scale_factors : list - List of scale factors for the image. Default is [2, 4, 8]. This will load the image at 4 different resolutions to allow for fluid visualization. - """ - - if scale_factors is None: - scale_factors = [2, 4, 8] - if self.sdata is None: - self._read_sdata() - - if channel_names is None: - self.channel_names = [f"channel_{i}" for i in range(image.shape[0])] - else: - self.channel_names = channel_names - - # transform to spatialdata image model - transform_original = Identity() - image = Image2DModel.parse( - image, - dims=["c", "y", "x"], - chunks=chunks, - c_coords=self.channel_names, - scale_factors=scale_factors, - transformations={"global": transform_original}, - rgb=False, - ) - - if overwrite: - self._force_delete_object(image_name, "images") - - self.sdata.images[image_name] = image - self.sdata.write_element(image_name, overwrite=True) - - self.log(f"Image {image_name} written to sdata object.") - - # track that input image has been loaded - self.input_image_status = True - - #### Functions for getting elements from sdata object ##### - - def _load_seg_to_memmap(self, seg_name: list[str], tmp_dir_abs_path: str): - """ - Helper function to load segmentation masks from sdata to memory mapped temp arrays for faster access. - Loading happens in a chunked manner to avoid memory issues. - - The function will return the path to the memory mapped array. 
- - Parameters - ---------- - seg_name : List[str] - List of segmentation element names that should be loaded found in the sdata object. - The segmentation elments need to have the same size. - tmp_dir_abs_path : str - Absolute path to the directory where the memory mapped arrays should be stored. - - Returns - ------- - str - Path to the memory mapped array. Can be reconneted to using the `mmap_array_from_path` - function from the alphabase.io.tempmmap module. - """ - - # ensure all elements are loaded - if self.sdata is None: - self._check_sdata_status() - - # get the segmentation object - assert all( - seg in self.sdata.labels for seg in seg_name - ), "Not all passed segmentation elements found in sdata object." - seg_objects = [self.sdata.labels[seg] for seg in seg_name] - - # get the shape of the segmentation - shapes = [seg.shape for seg in seg_objects] - - Z, Y, X = None, None, None - for shape in shapes: - if len(shape) == 2: - if Y is None: - Y, X = shape - else: - # ensure that all seg masks have the same shape - assert Y == shape[0] - assert X == shape[1] - elif len(shape) == 3: - if Z is None: - Z, Y, X = shape - else: - # ensure that all seg masks have the same shape - assert Z == shape[0] - assert Y == shape[1] - assert X == shape[2] - - # get the number of masks - n_masks = len(seg_objects) - - if Z is not None: - shape = (n_masks, Z, Y, X) - else: - shape = (n_masks, Y, X) - - # initialize empty memory mapped arrays to store the data - path_seg_masks = tempmmap.create_empty_mmap( - shape=shape, - dtype=self.DEFAULT_SEGMENTATION_DTYPE, - tmp_dir_abs_path=tmp_dir_abs_path, - ) - - # create the empty mmap array - seg_masks = tempmmap.mmap_array_from_path(path_seg_masks) - - # load the data into the mmap array in chunks - for i, seg in enumerate(seg_objects): - if Z is not None: - for z in range(Z): - seg_masks[i][z] = seg.data[z].compute() - else: - seg_masks[i] = seg.data.compute() - - # cleanup the cache - 
self._clear_cache(vars_to_delete=[seg_objects, seg_masks, seg]) - - return path_seg_masks - - def _load_input_image_to_memmap(self, tmp_dir_abs_path: str): - """ - Helper function to load the input image from sdata to memory mapped temp arrays for faster access. - Loading happens in a chunked manner to avoid memory issues. - - The function will return the path to the memory mapped array. - - Parameters - ---------- - tmp_dir_abs_path : str - Absolute path to the directory where the memory mapped arrays should be stored. - - Returns - ------- - str - Path to the memory mapped array. Can be reconneted to using the `mmap_array_from_path` - function from the alphabase.io.tempmmap module. - """ - # ensure all elements are loaded - if self.sdata is None: - self._check_sdata_status() - - if not self.input_image_status: - raise ValueError("Input image not found in sdata object.") - - shape = self.input_image.shape - - # initialize empty memory mapped arrays to store the data - path_input_image = tempmmap.create_empty_mmap( - shape=shape, - dtype=self.DEFAULT_IMAGE_DTYPE, - tmp_dir_abs_path=tmp_dir_abs_path, - ) - - # create the empty mmap array - input_image = tempmmap.mmap_array_from_path(path_input_image) - - # load the data into the mmap array in chunks - Z = None - if len(shape) == 3: - C, Y, X = shape - - elif len(shape) == 4: - Z, C, Y, X = shape - - if Z is not None: - for z in range(Z): - for c in range(C): - input_image[z][c] = self.input_image[z][c].compute() - else: - for c in range(C): - input_image[c] = self.input_image[c].compute() - - # cleanup the cache - self._clear_cache(vars_to_delete=[input_image]) - - return path_input_image - #### Functions to load input data #### def load_input_from_array( self, array: np.ndarray, channel_names: list[str] = None, overwrite: bool | None = None, remap: list[int] = None @@ -697,7 +587,7 @@ def load_input_from_array( self.channel_names = [self.channel_names[i] for i in remap] # write to sdata object - 
self._write_image_sdata( + self.filehandler._write_image_sdata( image, channel_names=self.channel_names, scale_factors=[2, 4, 8], @@ -853,7 +743,7 @@ def extract_unique_parts(paths: list[str]): channels = daskmmap.dask_array_from_path(temp_image_path) - self._write_image_sdata( + self.filehandler._write_image_sdata( channels, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=self.channel_names, @@ -973,7 +863,7 @@ def load_input_from_sdata( # check coordinate system of input image ### PLACEHOLDER - self._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME) + self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME) self.input_image_status = True # check if a nucleus segmentation exists and if so add it to the sdata object @@ -1205,6 +1095,7 @@ def select( if not self.nuc_seg_status or not self.cyto_seg_status: raise ValueError("No nucleus or cytosol segmentation loaded. Please load a segmentation first.") + assert self.sdata is not None, "No sdata object loaded." assert segmentation_name in self.sdata.labels, f"Segmentation {segmentation_name} not found in sdata object." self.selection_f(