Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Dec 11, 2024
1 parent 0be5e59 commit 8023b1c
Showing 1 changed file with 117 additions and 31 deletions.
148 changes: 117 additions & 31 deletions src/scportrait/pipeline/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
project
=======
At the core of scPortrait is the concept of a `Project`. A `Project` is a Python class that orchestrates all scPortrait processing steps, serving as the central element for all operations.
Each `Project` corresponds to a directory on the file system, which houses the input data for a specific scPortrait run along with the generated outputs.
At the core of scPortrait is the concept of a `Project`. A `Project` is a Python class that orchestrates all scPortrait processing steps, serving as the central element for all operations.
Each `Project` corresponds to a directory on the file system, which houses the input data for a specific scPortrait run along with the generated outputs.
The choice of the appropriate `Project` class depends on the structure of the data to be processed.
For more details, refer to :ref:`here <projects>`.
"""

from __future__ import annotations

import os
import re
import shutil
import tempfile
import warnings
from time import time
from typing import Literal
from typing import TYPE_CHECKING, Literal

import dask.array as darray
import numpy as np
Expand All @@ -27,14 +29,11 @@
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel
from spatialdata.transformations.transformations import Identity
from tifffile import imread

from scportrait.io import daskmmap
from scportrait.pipeline._base import Logable
from scportrait.pipeline._utils.sdata_io import sdata_filehandler
from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel
from scportrait.pipeline._utils.spatialdata_helper import (
calculate_centroids,
generate_region_annotation_lookuptable,
Expand All @@ -44,17 +43,27 @@
remap_region_annotation_table,
)

if TYPE_CHECKING:
from collections.abc import Callable


class Project(Logable):
"""Base implementation for a scPortrait project.
"""Base implementation for a scPortrait ``project``.
This class is designed to handle single-timepoint, single-location data, like e.g. whole-slide images.
Segmentation Methods should be based on :func:`Segmentation <scportrait.pipeline.segmentation.Segmentation>` or :func:`ShardedSegmentation <scportrait.pipeline.segmentation.ShardedSegmentation>` or ShardedSegmentation.
Segmentation Methods should be based on :func:`Segmentation <scportrait.pipeline.segmentation.Segmentation>` or :func:`ShardedSegmentation <scportrait.pipeline.segmentation.ShardedSegmentation>`.
Extraction Methods should be based on :func:`HDF5CellExtraction <scportrait.pipeline.extraction.HDF5CellExtraction>`.
Attributes:
config (dict): Dictionary containing the config file.
nuc_seg_name (str): Name of the nucleus segmentation object.
cyto_seg_name (str): Name of the cytosol segmentation object.
sdata_path (str): Path to the spatialdata object.
filehandler (sdata_filehandler): Filehandler for the spatialdata object which manages all calls or updates to the spatialdata object.
"""
CLEAN_LOG = True

CLEAN_LOG: bool = True
DEFAULT_CONFIG_NAME = "config.yml"
DEFAULT_INPUT_IMAGE_NAME = "input_image"
DEFAULT_SDATA_FILE = "scportrait.sdata"
Expand All @@ -63,10 +72,10 @@ class Project(Logable):
DEFAULT_PREFIX_FILTERED_SEG = "seg_filtered"
DEFAULT_PREFIX_SELECTED_SEG = "seg_selected"

DEFAULT_SEG_NAME_0 = "nucleus"
DEFAULT_SEG_NAME_1 = "cytosol"
DEFAULT_SEG_NAME_0: str = "nucleus"
DEFAULT_SEG_NAME_1: str = "cytosol"

DEFAULT_CENTERS_NAME = "centers_cells"
DEFAULT_CENTERS_NAME: str = "centers_cells"

DEFAULT_CHUNK_SIZE = (1, 1000, 1000)

Expand All @@ -84,15 +93,26 @@ class Project(Logable):

def __init__(
self,
project_location,
config_path,
project_location: str,
config_path: str,
segmentation_f=None,
extraction_f=None,
featurization_f=None,
selection_f=None,
overwrite=False,
debug=False,
overwrite: bool = False,
debug: bool = False,
):
"""
Args:
project_location (str): Path to the project directory.
config_path (str): Path to the config file.
segmentation_f (optional): Segmentation method to be used for the project.
extraction_f (optional): Extraction method to be used for the project.
featurization_f (optional): Featurization method to be used for the project.
selection_f (optional): Selection method to be used for the project.
overwrite (optional): If set to ``True``, will overwrite existing files in the project directory. Default is ``False``.
debug (optional): If set to ``True``, will print additional debug messages/plots. Default is ``False``.
"""
super().__init__(directory=project_location, debug=debug)

self.project_location = project_location
Expand Down Expand Up @@ -145,6 +165,12 @@ def __init__(
# ==== setup selection ===
self._setup_selection(selection_f)

def __exit__(self, exc_type=None, exc_value=None, traceback=None):
    """Clean up the temporary directory when leaving a ``with`` block.

    The context-manager protocol always calls ``__exit__`` with three
    positional arguments describing the exception (or ``None`` three times
    when the block finished normally). They are accepted here so the call
    does not fail with a ``TypeError``; the defaults keep a bare
    ``obj.__exit__()`` call working as before.

    Args:
        exc_type: Exception class raised inside the ``with`` block, if any.
        exc_value: Exception instance raised inside the ``with`` block, if any.
        traceback: Traceback associated with the exception, if any.

    Returns:
        None: a falsy return value means an exception raised inside the
        ``with`` block is never suppressed.
    """
    # NOTE(review): _clear_temp_dir is defined outside this view; presumably it
    # removes the directory created by _create_temp_dir - confirm.
    self._clear_temp_dir()

def __del__(self):
    """Run the same temporary-directory cleanup when the object is finalized.

    Finalizers are not guaranteed to run at a predictable time (or at all at
    interpreter shutdown), so this is only a best-effort fallback.
    """
    self._clear_temp_dir()

##### Setup Functions #####

def _load_config_from_file(self, file_path):
Expand All @@ -165,7 +191,15 @@ def _load_config_from_file(self, file_path):
except yaml.YAMLError as exc:
print(exc)

def _get_config_file(self, config_path: str | None = None):
def _get_config_file(self, config_path: str | None = None) -> None:
"""Load the config file for the project. If no config file is passed, the default config file in the project directory is loaded.
Args:
config_path (str, optional): Path to the config file. Default is ``None``.
Returns:
None: the config dictionary project.config is updated.
"""
# load config file
self.config_path = os.path.join(self.project_location, self.DEFAULT_CONFIG_NAME)

Expand Down Expand Up @@ -196,7 +230,15 @@ def _get_config_file(self, config_path: str | None = None):
self._load_config_from_file(self.config_path)

def _setup_segmentation_f(self, segmentation_f):
if self.segmentation_f is not None:
"""Configure the segmentation method for the project.
Args:
segmentation_f (Callable): Segmentation method to be used for the project.
Returns:
None: the segmentation method is updated in the project object.
"""
if segmentation_f is not None:
if segmentation_f.__name__ not in self.config:
raise ValueError(f"Config for {segmentation_f.__name__} is missing from the config file.")

Expand All @@ -218,6 +260,15 @@ def _setup_segmentation_f(self, segmentation_f):
)

def _setup_extraction_f(self, extraction_f):
"""Configure the extraction method for the project.
Args:
extraction_f (Callable): Extraction method to be used for the project.
Returns:
None: the extraction method is updated in the project object.
"""

if extraction_f is not None:
extraction_directory = os.path.join(self.project_location, self.DEFAULT_EXTRACTION_DIR_NAME)

Expand All @@ -237,6 +288,14 @@ def _setup_extraction_f(self, extraction_f):
)

def _setup_featurization_f(self, featurization_f):
"""Configure the featurization method for the project.
Args:
featurization_f (Callable): Featurization method to be used for the project.
Returns:
None: the featurization method is updated in the project object.
"""
if featurization_f is not None:
if featurization_f.__name__ not in self.config:
raise ValueError(f"Config for {featurization_f.__name__} is missing from the config file")
Expand All @@ -256,6 +315,14 @@ def _setup_featurization_f(self, featurization_f):
)

def _setup_selection(self, selection_f):
"""Configure the selection method for the project.
Args:
selection_f (Callable): Selection method to be used for the project.
Returns:
None: the selection method is updated in the project object.
"""
if self.selection_f is not None:
if selection_f.__name__ not in self.config:
raise ValueError(f"Config for {selection_f.__name__} is missing from the config file")
Expand All @@ -274,19 +341,20 @@ def _setup_selection(self, selection_f):
filehandler=self.filehandler,
)

def update_featurization_f(self, featurization_f) -> None:
def update_featurization_f(self, featurization_f):
"""Update the featurization method chosen for the project without reinitializing the entire project.
Args:
featurization_f: The featurization method that should be used for the project.
featurization_f : The featurization method that should be used for the project.
Returns:
None: the featurization method is updated in the project object.
None : the featurization method is updated in the project object.
Examples:
Update the featurization method for a project::
from scportrait.pipeline.featurization import CellFeaturizer
project.update_featurization_f(CellFeaturizer)
"""
self.log(f"Replacing current featurization method {self.featurization_f.__class__} with {featurization_f}")
Expand Down Expand Up @@ -326,8 +394,18 @@ def _check_chunk_size(self, elem):

return elem

def _check_image_dtype(self, image):
"""Check if the image dtype is the default image dtype. If not raise a warning."""
def _check_image_dtype(self, image: np.ndarray) -> None:
"""Check if the image dtype is the default image dtype.
Args:
image (np.ndarray): Image to be checked.
Returns:
None: If the image dtype is the default image dtype, no action is taken.
Raises:
Warning: If the image dtype is not the default image dtype.
"""

if not image.dtype == self.DEFAULT_IMAGE_DTYPE:
Warning(
Expand All @@ -337,14 +415,22 @@ def _check_image_dtype(self, image):
f"Image dtype is not {self.DEFAULT_IMAGE_DTYPE} but insteadt {image.dtype}. The workflow expects images to be of dtype {self.DEFAULT_IMAGE_DTYPE}. Proceeding with the incorrect dtype can lead to unexpected results."
)

def _create_temp_dir(self, path):
def _create_temp_dir(self, path) -> None:
"""
Create a temporary directory in the specified directory with the name of the class.s
Create a temporary directory in the specified directory with the name of the class.
Args:
path (str): Path to the directory where the temporary directory should be created.
Returns:
None: The temporary directory is created in the specified directory. The path to the temporary directory is stored in the project object as self._tmp_dir_path.
"""

path = os.path.join(path, f"{self.__class__.__name__}_")
self._tmp_dir = tempfile.TemporaryDirectory(prefix=path)
self._tmp_dir_path = self._tmp_dir.name
"""str: Path to the temporary directory."""

self.log(f"Initialized temporary directory at {self._tmp_dir_path} for {self.__class__.__name__}")

Expand Down Expand Up @@ -396,8 +482,7 @@ def _ensure_all_labels_habe_cell_ids(self):
self.sdata.labels[keys].attrs["cell_ids"] = get_unique_cell_ids(self.sdata.labels[keys])

def print_project_status(self):
"""Print the current project status.
"""
"""Print the current project status."""
self._check_sdata_status(print_status=True)

def _check_sdata_status(self, print_status=False):
Expand Down Expand Up @@ -430,7 +515,7 @@ def _check_sdata_status(self, print_status=False):
return None

def _read_sdata(self):
self.sdata = self.filehandler.get_sdata()
self.sdata = self.filehandler.get_sdata() # type: SpatialData
self._check_sdata_status()

def view_sdata(self):
Expand Down Expand Up @@ -502,7 +587,7 @@ def load_input_from_array(
self.channel_names = [self.channel_names[i] for i in remap]

# write to sdata object
self._write_image_sdata(
self.filehandler._write_image_sdata(
image,
channel_names=self.channel_names,
scale_factors=[2, 4, 8],
Expand Down Expand Up @@ -658,7 +743,7 @@ def extract_unique_parts(paths: list[str]):

channels = daskmmap.dask_array_from_path(temp_image_path)

self._write_image_sdata(
self.filehandler._write_image_sdata(
channels,
self.DEFAULT_INPUT_IMAGE_NAME,
channel_names=self.channel_names,
Expand Down Expand Up @@ -778,7 +863,7 @@ def load_input_from_sdata(
# check coordinate system of input image
### PLACEHOLDER

self._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME)
self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME)
self.input_image_status = True

# check if a nucleus segmentation exists and if so add it to the sdata object
Expand Down Expand Up @@ -1010,6 +1095,7 @@ def select(
if not self.nuc_seg_status or not self.cyto_seg_status:
raise ValueError("No nucleus or cytosol segmentation loaded. Please load a segmentation first.")

assert self.sdata is not None, "No sdata object loaded."
assert segmentation_name in self.sdata.labels, f"Segmentation {segmentation_name} not found in sdata object."

self.selection_f(
Expand Down

0 comments on commit 8023b1c

Please sign in to comment.