Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typehinting and docstrings #183

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 67 additions & 115 deletions src/scportrait/tools/stitch/_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
stitch
=======

Contains functions to assemble tiled images into fullscale mosaics. Uses out-of-memory computation for the assembly of larger than memory image mosaics.
Functions to assemble tiled images into fullscale mosaics.
Uses out-of-memory computation for the assembly of larger than memory image mosaics.
"""

import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING

import numpy as np
from alphabase.io.tempmmap import (
Expand Down Expand Up @@ -66,44 +68,25 @@ def __init__(
overwrite: bool = False,
cache: str = None,
) -> None:
"""
Initialize the Stitcher object.

Parameters:
-----------
input_dir : str
Directory containing the input image tiles.
slidename : str
Name of the slide.
outdir : str
Output directory to save the stitched mosaic.
stitching_channel : str
Name of the channel to be used for stitching.
pattern : str
File pattern to match the image tiles.
overlap : float, optional
Overlap between adjacent image tiles (default is 0.1).
max_shift : float, optional
Maximum allowed shift during alignment (default is 30).
filter_sigma : int, optional
Sigma value for Gaussian filter applied during alignment (default is 0).
do_intensity_rescale : bool or "full_image", optional
Flag to indicate whether to rescale image intensities (default is True). Alternatively, set to "full_image" to rescale the entire image.
rescale_range : tuple or dictionary, optional
If all channels should be rescaled to the same range pass a tuple with the percentiles for rescaleing (default is (1, 99)). Alternatively
a dictionary can be passed with the channel names as keys and the percentiles as values if each channel should be rescaled to a different range.
channel_order : list, optional
Order of channels in the generated output mosaic. If none (default value) the order of the channels is left unchanged.
reader_type : class, optional
Type of reader to use for reading image tiles (default is FilePatternReaderRescale).
orientation : dict, optional
Dictionary specifiying which dimensions of the slide to flip (default is {'flip_x': False, 'flip_y': True}).
plot_QC : bool, optional
Flag to indicate whether to plot quality control (QC) figures (default is True).
overwrite : bool, optional
Flag to indicate whether to overwrite the output directory if it already exists (default is False).
cache : str, optional
Directory to store temporary files during stitching (default is None). If set to none this directory will be created in the outdir.
"""Initialize the Stitcher object.

Args:
input_dir: Directory containing the input image tiles
slidename: Name of the slide
outdir: Output directory to save the stitched mosaic
stitching_channel: Name of the channel to be used for stitching
pattern: File pattern to match the image tiles
overlap: Overlap between adjacent image tiles
max_shift: Maximum allowed shift during alignment
filter_sigma: Sigma value for Gaussian filter applied during alignment
do_intensity_rescale: Flag to rescale image intensities or "full_image" to rescale entire image
rescale_range: Percentiles for intensity rescaling as tuple or dict with channel names as keys
channel_order: Order of channels in output mosaic
reader_type: Type of reader for image tiles
orientation: Dict specifying dimensions to flip {'flip_x', 'flip_y'}
plot_QC: Generate quality control figures
overwrite: Overwrite existing output directory
cache: Directory for temporary files
"""
self._lazy_imports()

Expand Down Expand Up @@ -151,9 +134,7 @@ def __init__(
self.reader = None

def _lazy_imports(self):
"""
Import necessary packages for stitching.
"""
"""Import necessary packages for stitching."""
from ashlar import thumbnail
from ashlar.reg import EdgeAligner, Mosaic
from ashlar.scripts.ashlar import process_axis_flip
Expand All @@ -180,9 +161,7 @@ def __del__(self):
self._clear_cache()

def _create_cache(self):
"""
Create a temporary cache directory for storing intermediate files during stitching.
"""
"""Create a temporary cache directory for storing intermediate files during stitching."""
if self.cache is None:
TEMP_DIR_NAME = redefine_temp_location(self.outdir)
else:
Expand All @@ -191,17 +170,13 @@ def _create_cache(self):
self.TEMP_DIR_NAME = TEMP_DIR_NAME

def _clear_cache(self):
"""
Clear the temporary cache directory.
"""
"""Clear the temporary cache directory."""
if "TEMP_DIR_NAME" in self.__dict__:
if os.path.exists(self.TEMP_DIR_NAME):
shutil.rmtree(self.TEMP_DIR_NAME)

def _initialize_outdir(self):
"""
Initialize the output directory for saving the stitched mosaic.
"""
"""Initialize the output directory for saving the stitched mosaic."""
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
print("Output directory created at: ", self.outdir)
Expand All @@ -216,9 +191,7 @@ def _initialize_outdir(self):
)

def _get_channel_info(self):
"""
Get information about the channels in the image tiles.
"""
"""Get information about the channels in the image tiles."""
# get channel names
self.channel_lookup = self.reader.metadata.channel_map
self.channel_names = list(self.reader.metadata.channel_map.values())
Expand All @@ -240,9 +213,7 @@ def get_stitching_information(self):
print("Output will be written to:", self.outdir)

def _setup_rescaling(self):
"""
Setup image rescaling based on the specified rescale_range.
"""
"""Setup image rescaling based on the specified rescale_range."""
# set up rescaling
if self.do_intensity_rescale:
self.reader.no_rescale_channel = []
Expand All @@ -263,7 +234,8 @@ def _setup_rescaling(self):

if len(missing_channels) > 0:
Warning(
"The rescale_range dictionary does not contain all channels in the experiment. This may lead to unexpected results. For the missing channels rescaling will be turned off."
"The rescale_range dictionary does not contain all channels in the experiment."
"This may lead to unexpected results. For the missing channels rescaling will be turned off."
)

missing_channels = set.difference(self.channel_names, rescale_channels)
Expand All @@ -285,9 +257,7 @@ def _setup_rescaling(self):
self.reader.rescale_range = None

def _reorder_channels(self):
"""
Reorder the channels in the mosaic based on the specified channel_order.
"""
"""Reorder the channels in the mosaic based on the specified channel_order."""
if self.channel_order is None:
self.channels = self.channels
else:
Expand All @@ -301,9 +271,7 @@ def _reorder_channels(self):
self.channels = channels

def _initialize_reader(self):
"""
Initialize the reader for reading image tiles.
"""
"""Initialize the reader for reading image tiles."""
if self.reader_type == self.FilePatternReaderRescale:
self.reader = self.reader_type(
self.input_dir,
Expand All @@ -326,23 +294,19 @@ def _initialize_reader(self):
self._setup_rescaling()

def save_positions(self):
"""
Save the positions of the aligned image tiles.
"""
"""Save the positions of the aligned image tiles."""
positions = self.aligner.positions
np.savetxt(
os.path.join(self.outdir, self.slidename + "_tile_positions.tsv"),
positions,
delimiter="\t",
)

def generate_thumbnail(self, scale=0.05):
"""
Generate a thumbnail of the stitched mosaic.
def generate_thumbnail(self, scale: float | None = 0.05) -> None:
"""Generate a thumbnail of the stitched mosaic.

Args:
scale (float, optional): Scale factor for the thumbnail. Defaults to 0.05.

scale: Scale factor for the thumbnail.
"""
self._initialize_reader()
self.thumbnail = self.ashlar_thumbnail.make_thumbnail(
Expand All @@ -355,7 +319,7 @@ def generate_thumbnail(self, scale=0.05):
rescale_range = {k: self.rescale_range for k in self.channel_names}
rescale = True
elif type(self.rescale_range) is dict:
rescale_range = self.rescale_range[self.stitching_channel]
rescale_range = self.rescale_range[self.stitching_channel] # type: ignore
rescale = True
else:
if not self.do_intensity_rescale:
Expand All @@ -366,11 +330,10 @@ def generate_thumbnail(self, scale=0.05):
self.thumbnail = rescale_image(self.thumbnail, rescale_range)

def _initialize_aligner(self):
"""
Initialize the aligner for aligning the image tiles.
"""Initialize the aligner for aligning the image tiles.

Returns:
aligner (EdgeAligner): Initialized EdgeAligner object.
Initialized EdgeAligner object.
"""
aligner = self.ashlar_EdgeAligner(
self.reader,
Expand All @@ -383,16 +346,12 @@ def _initialize_aligner(self):
return aligner

def plot_qc(self):
"""
Plot quality control (QC) figures for the alignment.
"""
"""Plot quality control (QC) figures for the alignment."""
plot_edge_scatter(self.aligner, self.outdir)
plot_edge_quality(self.aligner, self.outdir)

def _perform_alignment(self):
"""
Perform alignment of the image tiles.
"""
"""Perform alignment of the image tiles."""
# intitialize reader for getting individual image tiles
self._initialize_reader()

Expand All @@ -411,11 +370,10 @@ def _perform_alignment(self):
print("Alignment complete.")

def _initialize_mosaic(self):
"""
Initialize the mosaic object for assembling the image tiles.
"""Initialize the mosaic object for assembling the image tiles.

Returns:
mosaic (Mosaic): Initialized Mosaic object.
Initialized Mosaic object.
"""
mosaic = self.ashlar_Mosaic(
self.aligner,
Expand All @@ -426,9 +384,7 @@ def _initialize_mosaic(self):
return mosaic

def _assemble_mosaic(self):
"""
Assemble the image tiles into a mosaic.
"""
"""Assemble the image tiles into a mosaic."""
# get dimensions of assembled final mosaic
x, y = self.mosaic.shape
shape = (self.n_channels, x, y)
Expand Down Expand Up @@ -472,18 +428,16 @@ def _generate_mosaic(self):
self._assemble_mosaic()

def stitch(self):
"""
Generate the stitched mosaic.
"""
"""Generate the stitched mosaic."""
self._perform_alignment()
self._generate_mosaic()

def write_tif(self, export_xml: bool = True) -> None:
"""
Write the assembled mosaic as TIFF files.
"""Write the assembled mosaic as TIFF files.

Args:
export_xml (bool, optional): Flag to indicate whether to export an XML file for the TIFF files (default is True). This XML file is compatible with loading the generated TIFF files into BIAS.
export_xml: Whether to export an XML file for the TIFF files.
This XML file is compatible with loading the generated TIFF files into BIAS.

Returns:
The assembled mosaic are written to file as TIFF files in the specified output directory.
Expand All @@ -497,16 +451,18 @@ def write_tif(self, export_xml: bool = True) -> None:
if export_xml:
write_xml(filenames, self.channel_names, self.slidename)

def write_ome_zarr(self, downscaling_size=4, n_downscaling_layers=4, chunk_size=(1, 1024, 1024)):
def write_ome_zarr(
self,
downscaling_size: int = 4,
n_downscaling_layers: int = 4,
chunk_size: tuple[int, int, int] = (1, 1024, 1024),
) -> None:
"""Write the assembled mosaic as an OME-Zarr file.

Args:
downscaling_size (int, optional): Downscaling factor for generating lower resolution layers (default is 4).
n_downscaling_layers (int, optional): Number of downscaling layers to generate (default is 4).
chunk_size (tuple, optional): Chunk size for the generated OME-Zarr file (default is (1, 1024, 1024)).

Returns:
None
downscaling_size: Downscaling factor for generating lower resolution layers (default is 4).
n_downscaling_layers: Number of downscaling layers to generate (default is 4).
chunk_size: Chunk size for the generated OME-Zarr file (default is (1, 1024, 1024)).
"""
filepath = os.path.join(self.outdir, f"{self.slidename}.ome.zarr")

Expand All @@ -522,9 +478,7 @@ def write_ome_zarr(self, downscaling_size=4, n_downscaling_layers=4, chunk_size=
)

def write_thumbnail(self):
"""
Write the generated thumbnail as a TIFF file.
"""
"""Write the generated thumbnail as a TIFF file."""
# calculate thumbnail if this has not already been done
if "thumbnail" not in self.__dict__:
self.generate_thumbnail()
Expand All @@ -535,13 +489,12 @@ def write_thumbnail(self):
)
write_tif(filename, self.thumbnail)

def write_spatialdata(self, scale_factors=None):
"""
Write the assembled mosaic as a SpatialData object.
def write_spatialdata(self, scale_factors: list[int] | None = None) -> None:
"""Write the assembled mosaic as a SpatialData object.

Args:
scale_factors (list, optional): List of scale factors for the generated SpatialData object.
Default is [2, 4, 8]. The scale factors are used to generate downsampled versions of the
scale_factors: List of scale factors for the generated SpatialData object.
Defaults to [2, 4, 8]. The scale factors are used to generate downsampled versions of the
image for faster visualization at lower resolutions.
"""
if scale_factors is None:
Expand All @@ -559,8 +512,7 @@ def write_spatialdata(self, scale_factors=None):


class ParallelStitcher(Stitcher):
"""
Class for parallel stitching of image tiles and generating a mosaic. For applicable steps multi-threading is used for faster processing.
"""Class for parallel stitching of image tiles and generating a mosaic. For applicable steps multi-threading is used for faster processing.

Args:
input_dir (str): Directory containing the input image tiles.
Expand Down Expand Up @@ -637,8 +589,7 @@ def __init__(
self.threads = threads

def _initialize_aligner(self):
"""
Initialize the aligner for aligning the image tiles.
"""Initialize the aligner for aligning the image tiles.

Returns:
aligner (ParallelEdgeAligner): Initialized ParallelEdgeAligner object.
Expand Down Expand Up @@ -707,11 +658,12 @@ def _assemble_mosaic(self):
# conver to dask array
self.assembled_mosaic = dask_array_from_path(hdf5_path)

def write_tif_parallel(self, export_xml=True):
def write_tif_parallel(self, export_xml: bool = True):
"""Parallelized version of the write_tif method to write the assembled mosaic as TIFF files.

Args:
export_xml (bool, optional): Flag to indicate whether to export an XML file for the TIFF files (default is True). This XML file is compatible with loading the generarted TIFF files into BIAS.
export_xml: Whether to export an XML file for the TIFF files.
This XML file is compatible with loading the generarted TIFF files into BIAS.

"""

Expand Down
5 changes: 4 additions & 1 deletion src/scportrait/tools/stitch/_utils/filereaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def channel_map(self):


class BioformatsReaderRescale(BioformatsReader):
"""Class for reading images from Bioformats files (e.g. nd2). If desired the images can be rescaled to a certain range while reading."""
"""Class for reading images from Bioformats files (e.g. nd2).

If desired the images can be rescaled to a certain range while reading.
"""

def __init__(self, path, plate=None, well=None, do_rescale=False, no_rescale_channel=None, rescale_range=(1, 99)):
self.path = path
Expand Down