Skip to content

Commit

Permalink
types
Browse files Browse the repository at this point in the history
Signed-off-by: Lukas Heumos <[email protected]>
  • Loading branch information
Zethson committed Feb 12, 2025
1 parent b717064 commit eef85cf
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 118 deletions.
187 changes: 70 additions & 117 deletions src/scportrait/tools/stitch/_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
stitch
=======
Contains functions to assemble tiled images into fullscale mosaics. Uses out-of-memory computation for the assembly of larger than memory image mosaics.
Functions to assemble tiled images into fullscale mosaics.
Uses out-of-memory computation for the assembly of larger than memory image mosaics.
"""

import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING

import numpy as np
from alphabase.io.tempmmap import (
create_empty_mmap,
mmap_array_from_path,
redefine_temp_location,
)
from ashlar.reg import EdgeAligner, Mosaic
from tqdm import tqdm

from scportrait.io.daskmmap import dask_array_from_path
Expand Down Expand Up @@ -66,44 +69,25 @@ def __init__(
overwrite: bool = False,
cache: str = None,
) -> None:
"""
Initialize the Stitcher object.
Parameters:
-----------
input_dir : str
Directory containing the input image tiles.
slidename : str
Name of the slide.
outdir : str
Output directory to save the stitched mosaic.
stitching_channel : str
Name of the channel to be used for stitching.
pattern : str
File pattern to match the image tiles.
overlap : float, optional
Overlap between adjacent image tiles (default is 0.1).
max_shift : float, optional
Maximum allowed shift during alignment (default is 30).
filter_sigma : int, optional
Sigma value for Gaussian filter applied during alignment (default is 0).
do_intensity_rescale : bool or "full_image", optional
Flag to indicate whether to rescale image intensities (default is True). Alternatively, set to "full_image" to rescale the entire image.
rescale_range : tuple or dictionary, optional
If all channels should be rescaled to the same range pass a tuple with the percentiles for rescaleing (default is (1, 99)). Alternatively
a dictionary can be passed with the channel names as keys and the percentiles as values if each channel should be rescaled to a different range.
channel_order : list, optional
Order of channels in the generated output mosaic. If none (default value) the order of the channels is left unchanged.
reader_type : class, optional
Type of reader to use for reading image tiles (default is FilePatternReaderRescale).
orientation : dict, optional
Dictionary specifiying which dimensions of the slide to flip (default is {'flip_x': False, 'flip_y': True}).
plot_QC : bool, optional
Flag to indicate whether to plot quality control (QC) figures (default is True).
overwrite : bool, optional
Flag to indicate whether to overwrite the output directory if it already exists (default is False).
cache : str, optional
Directory to store temporary files during stitching (default is None). If set to none this directory will be created in the outdir.
"""Initialize the Stitcher object.
Args:
input_dir: Directory containing the input image tiles
slidename: Name of the slide
outdir: Output directory to save the stitched mosaic
stitching_channel: Name of the channel to be used for stitching
pattern: File pattern to match the image tiles
overlap: Overlap between adjacent image tiles
max_shift: Maximum allowed shift during alignment
filter_sigma: Sigma value for Gaussian filter applied during alignment
do_intensity_rescale: Flag to rescale image intensities or "full_image" to rescale entire image
rescale_range: Percentiles for intensity rescaling as tuple or dict with channel names as keys
channel_order: Order of channels in output mosaic
reader_type: Type of reader for image tiles
orientation: Dict specifying dimensions to flip {'flip_x', 'flip_y'}
plot_QC: Generate quality control figures
overwrite: Overwrite existing output directory
cache: Directory for temporary files
"""
self._lazy_imports()

Expand Down Expand Up @@ -151,9 +135,7 @@ def __init__(
self.reader = None

def _lazy_imports(self):
"""
Import necessary packages for stitching.
"""
"""Import necessary packages for stitching."""
from ashlar import thumbnail
from ashlar.reg import EdgeAligner, Mosaic
from ashlar.scripts.ashlar import process_axis_flip
Expand All @@ -180,9 +162,7 @@ def __del__(self):
self._clear_cache()

def _create_cache(self):
"""
Create a temporary cache directory for storing intermediate files during stitching.
"""
"""Create a temporary cache directory for storing intermediate files during stitching."""
if self.cache is None:
TEMP_DIR_NAME = redefine_temp_location(self.outdir)
else:
Expand All @@ -191,17 +171,13 @@ def _create_cache(self):
self.TEMP_DIR_NAME = TEMP_DIR_NAME

def _clear_cache(self):
"""
Clear the temporary cache directory.
"""
"""Clear the temporary cache directory."""
if "TEMP_DIR_NAME" in self.__dict__:
if os.path.exists(self.TEMP_DIR_NAME):
shutil.rmtree(self.TEMP_DIR_NAME)

def _initialize_outdir(self):
"""
Initialize the output directory for saving the stitched mosaic.
"""
"""Initialize the output directory for saving the stitched mosaic."""
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
print("Output directory created at: ", self.outdir)
Expand All @@ -216,9 +192,7 @@ def _initialize_outdir(self):
)

def _get_channel_info(self):
"""
Get information about the channels in the image tiles.
"""
"""Get information about the channels in the image tiles."""
# get channel names
self.channel_lookup = self.reader.metadata.channel_map
self.channel_names = list(self.reader.metadata.channel_map.values())
Expand All @@ -240,9 +214,7 @@ def get_stitching_information(self):
print("Output will be written to:", self.outdir)

def _setup_rescaling(self):
"""
Setup image rescaling based on the specified rescale_range.
"""
"""Setup image rescaling based on the specified rescale_range."""
# set up rescaling
if self.do_intensity_rescale:
self.reader.no_rescale_channel = []
Expand All @@ -263,7 +235,8 @@ def _setup_rescaling(self):

if len(missing_channels) > 0:
Warning(
"The rescale_range dictionary does not contain all channels in the experiment. This may lead to unexpected results. For the missing channels rescaling will be turned off."
"The rescale_range dictionary does not contain all channels in the experiment."
"This may lead to unexpected results. For the missing channels rescaling will be turned off."
)

missing_channels = set.difference(self.channel_names, rescale_channels)
Expand All @@ -285,9 +258,7 @@ def _setup_rescaling(self):
self.reader.rescale_range = None

def _reorder_channels(self):
"""
Reorder the channels in the mosaic based on the specified channel_order.
"""
"""Reorder the channels in the mosaic based on the specified channel_order."""
if self.channel_order is None:
self.channels = self.channels
else:
Expand All @@ -301,9 +272,7 @@ def _reorder_channels(self):
self.channels = channels

def _initialize_reader(self):
"""
Initialize the reader for reading image tiles.
"""
"""Initialize the reader for reading image tiles."""
if self.reader_type == self.FilePatternReaderRescale:
self.reader = self.reader_type(
self.input_dir,
Expand All @@ -326,23 +295,19 @@ def _initialize_reader(self):
self._setup_rescaling()

def save_positions(self):
"""
Save the positions of the aligned image tiles.
"""
"""Save the positions of the aligned image tiles."""
positions = self.aligner.positions
np.savetxt(
os.path.join(self.outdir, self.slidename + "_tile_positions.tsv"),
positions,
delimiter="\t",
)

def generate_thumbnail(self, scale=0.05):
"""
Generate a thumbnail of the stitched mosaic.
def generate_thumbnail(self, scale: float | None = 0.05) -> None:
"""Generate a thumbnail of the stitched mosaic.
Args:
scale (float, optional): Scale factor for the thumbnail. Defaults to 0.05.
scale: Scale factor for the thumbnail.
"""
self._initialize_reader()
self.thumbnail = self.ashlar_thumbnail.make_thumbnail(
Expand All @@ -355,7 +320,7 @@ def generate_thumbnail(self, scale=0.05):
rescale_range = {k: self.rescale_range for k in self.channel_names}
rescale = True
elif type(self.rescale_range) is dict:
rescale_range = self.rescale_range[self.stitching_channel]
rescale_range = self.rescale_range[self.stitching_channel] # type: ignore
rescale = True
else:
if not self.do_intensity_rescale:
Expand All @@ -365,12 +330,11 @@ def generate_thumbnail(self, scale=0.05):
if rescale:
self.thumbnail = rescale_image(self.thumbnail, rescale_range)

def _initialize_aligner(self):
"""
Initialize the aligner for aligning the image tiles.
def _initialize_aligner(self) -> EdgeAligner:
"""Initialize the aligner for aligning the image tiles.
Returns:
aligner (EdgeAligner): Initialized EdgeAligner object.
Initialized EdgeAligner object.
"""
aligner = self.ashlar_EdgeAligner(
self.reader,
Expand All @@ -383,16 +347,12 @@ def _initialize_aligner(self):
return aligner

def plot_qc(self):
"""
Plot quality control (QC) figures for the alignment.
"""
"""Plot quality control (QC) figures for the alignment."""
plot_edge_scatter(self.aligner, self.outdir)
plot_edge_quality(self.aligner, self.outdir)

def _perform_alignment(self):
"""
Perform alignment of the image tiles.
"""
"""Perform alignment of the image tiles."""
# intitialize reader for getting individual image tiles
self._initialize_reader()

Expand All @@ -410,12 +370,11 @@ def _perform_alignment(self):

print("Alignment complete.")

def _initialize_mosaic(self):
"""
Initialize the mosaic object for assembling the image tiles.
def _initialize_mosaic(self) -> Mosaic:
"""Initialize the mosaic object for assembling the image tiles.
Returns:
mosaic (Mosaic): Initialized Mosaic object.
Initialized Mosaic object.
"""
mosaic = self.ashlar_Mosaic(
self.aligner,
Expand All @@ -426,9 +385,7 @@ def _initialize_mosaic(self):
return mosaic

def _assemble_mosaic(self):
"""
Assemble the image tiles into a mosaic.
"""
"""Assemble the image tiles into a mosaic."""
# get dimensions of assembled final mosaic
x, y = self.mosaic.shape
shape = (self.n_channels, x, y)
Expand Down Expand Up @@ -472,18 +429,16 @@ def _generate_mosaic(self):
self._assemble_mosaic()

def stitch(self):
"""
Generate the stitched mosaic.
"""
"""Generate the stitched mosaic."""
self._perform_alignment()
self._generate_mosaic()

def write_tif(self, export_xml: bool = True) -> None:
"""
Write the assembled mosaic as TIFF files.
"""Write the assembled mosaic as TIFF files.
Args:
export_xml (bool, optional): Flag to indicate whether to export an XML file for the TIFF files (default is True). This XML file is compatible with loading the generated TIFF files into BIAS.
export_xml: Whether to export an XML file for the TIFF files.
This XML file is compatible with loading the generated TIFF files into BIAS.
Returns:
The assembled mosaic are written to file as TIFF files in the specified output directory.
Expand All @@ -497,16 +452,18 @@ def write_tif(self, export_xml: bool = True) -> None:
if export_xml:
write_xml(filenames, self.channel_names, self.slidename)

def write_ome_zarr(self, downscaling_size=4, n_downscaling_layers=4, chunk_size=(1, 1024, 1024)):
def write_ome_zarr(
self,
downscaling_size: int = 4,
n_downscaling_layers: int = 4,
chunk_size: tuple[int, int, int] = (1, 1024, 1024),
) -> None:
"""Write the assembled mosaic as an OME-Zarr file.
Args:
downscaling_size (int, optional): Downscaling factor for generating lower resolution layers (default is 4).
n_downscaling_layers (int, optional): Number of downscaling layers to generate (default is 4).
chunk_size (tuple, optional): Chunk size for the generated OME-Zarr file (default is (1, 1024, 1024)).
Returns:
None
downscaling_size: Downscaling factor for generating lower resolution layers (default is 4).
n_downscaling_layers: Number of downscaling layers to generate (default is 4).
chunk_size: Chunk size for the generated OME-Zarr file (default is (1, 1024, 1024)).
"""
filepath = os.path.join(self.outdir, f"{self.slidename}.ome.zarr")

Expand All @@ -522,9 +479,7 @@ def write_ome_zarr(self, downscaling_size=4, n_downscaling_layers=4, chunk_size=
)

def write_thumbnail(self):
"""
Write the generated thumbnail as a TIFF file.
"""
"""Write the generated thumbnail as a TIFF file."""
# calculate thumbnail if this has not already been done
if "thumbnail" not in self.__dict__:
self.generate_thumbnail()
Expand All @@ -535,13 +490,12 @@ def write_thumbnail(self):
)
write_tif(filename, self.thumbnail)

def write_spatialdata(self, scale_factors=None):
"""
Write the assembled mosaic as a SpatialData object.
def write_spatialdata(self, scale_factors: list[int] | None = None) -> None:
"""Write the assembled mosaic as a SpatialData object.
Args:
scale_factors (list, optional): List of scale factors for the generated SpatialData object.
Default is [2, 4, 8]. The scale factors are used to generate downsampled versions of the
scale_factors: List of scale factors for the generated SpatialData object.
Defaults to [2, 4, 8]. The scale factors are used to generate downsampled versions of the
image for faster visualization at lower resolutions.
"""
if scale_factors is None:
Expand All @@ -559,8 +513,7 @@ def write_spatialdata(self, scale_factors=None):


class ParallelStitcher(Stitcher):
"""
Class for parallel stitching of image tiles and generating a mosaic. For applicable steps multi-threading is used for faster processing.
"""Class for parallel stitching of image tiles and generating a mosaic. For applicable steps multi-threading is used for faster processing.
Args:
input_dir (str): Directory containing the input image tiles.
Expand Down Expand Up @@ -637,8 +590,7 @@ def __init__(
self.threads = threads

def _initialize_aligner(self):
"""
Initialize the aligner for aligning the image tiles.
"""Initialize the aligner for aligning the image tiles.
Returns:
aligner (ParallelEdgeAligner): Initialized ParallelEdgeAligner object.
Expand Down Expand Up @@ -707,11 +659,12 @@ def _assemble_mosaic(self):
# conver to dask array
self.assembled_mosaic = dask_array_from_path(hdf5_path)

def write_tif_parallel(self, export_xml=True):
def write_tif_parallel(self, export_xml: bool = True):
"""Parallelized version of the write_tif method to write the assembled mosaic as TIFF files.
Args:
export_xml (bool, optional): Flag to indicate whether to export an XML file for the TIFF files (default is True). This XML file is compatible with loading the generarted TIFF files into BIAS.
export_xml: Whether to export an XML file for the TIFF files.
This XML file is compatible with loading the generarted TIFF files into BIAS.
"""

Expand Down
Loading

0 comments on commit eef85cf

Please sign in to comment.