From 6accb6c0b6bd2f87dcf7eaa4209ef989259a5192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 14:39:26 +0100 Subject: [PATCH 01/14] deprecate spLabels2DModel --- src/scportrait/pipeline/_utils/sdata_io.py | 17 +-- .../pipeline/_utils/spatialdata_classes.py | 108 ------------------ 2 files changed, 3 insertions(+), 122 deletions(-) delete mode 100644 src/scportrait/pipeline/_utils/spatialdata_classes.py diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 09fb5014..e168e34c 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -9,11 +9,10 @@ import xarray from alphabase.io import tempmmap from spatialdata import SpatialData -from spatialdata.models import Image2DModel, PointsModel, TableModel +from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations.transformations import Identity from scportrait.pipeline._base import Logable -from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel from scportrait.pipeline._utils.spatialdata_helper import ( calculate_centroids, get_chunk_size, @@ -87,13 +86,6 @@ def _read_sdata(self) -> SpatialData: _sdata = self._create_empty_sdata() _sdata.write(self.sdata_path, overwrite=True) - allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"] - for key in _sdata.labels: - if key in allowed_labels: - segmentation_object = _sdata.labels[key] - if not hasattr(segmentation_object.attrs, "cell_ids"): - segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None) - return _sdata def get_sdata(self) -> SpatialData: @@ -249,7 +241,7 @@ def _write_image_sdata( def _write_segmentation_object_sdata( self, - segmentation_object: spLabels2DModel, + segmentation_object: Labels2DModel, segmentation_label: str, classes: set[str] | None = None, overwrite: bool = False, @@ -264,9 +256,6 @@ def _write_segmentation_object_sdata( """ _sdata = self._read_sdata() - if not hasattr(segmentation_object.attrs, "cell_ids"): - segmentation_object = spLabels2DModel().convert(segmentation_object, classes=classes) - if overwrite: self._force_delete_object(_sdata, segmentation_label, "labels") @@ -294,7 +283,7 @@ def _write_segmentation_sdata( overwrite: Whether to overwrite existing data """ transform_original = Identity() - mask = spLabels2DModel.parse( + mask = Labels2DModel.parse( segmentation, dims=["y", "x"], transformations={"global": transform_original}, diff --git a/src/scportrait/pipeline/_utils/spatialdata_classes.py b/src/scportrait/pipeline/_utils/spatialdata_classes.py deleted file mode 100644 index 8af20a3e..00000000 --- a/src/scportrait/pipeline/_utils/spatialdata_classes.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Extended Labels2D Model with cell ID tracking.""" - -from functools import singledispatchmethod -from typing import Any - -from dask.array import unique as DaskUnique -from spatialdata.models import C, Labels2DModel, X, Y, Z, get_axes_names -from spatialdata.transformations.transformations import BaseTransformation -from xarray import DataArray, DataTree -from xarray_schema.components import ( - AttrSchema, - AttrsSchema, -) - -Transform_s = AttrSchema(BaseTransformation, None) - - -class spLabels2DModel(Labels2DModel): - """Extended Labels2DModel that maintains cell IDs in attributes.""" - - # Add attribute that always contains unique classes in labels image - attrs = AttrsSchema( - {"transform": Transform_s}, - {"cell_ids": set[int]}, # More specific type hint for set contents - ) - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize the extended Labels2D model.""" - super().__init__(*args, **kwargs) - - @classmethod - def parse(cls, *args: Any, **kwargs: Any) -> DataArray: - """Parse data and extract cell IDs. - - Returns: - DataArray with cell IDs in attributes - """ - data = super().parse(*args, **kwargs) - data = cls._get_cell_ids(data) - return data - - @staticmethod - def _get_cell_ids(data: DataArray, remove_background: bool = True) -> DataArray: - """Get unique values from labels image. - - Args: - data: Input label array - remove_background: Whether to remove background (0) label - - Returns: - DataArray with cell IDs added to attributes - """ - cell_ids = set(DaskUnique(data.data).compute()) - if remove_background: - cell_ids = cell_ids - {0} # Remove background class - data.attrs["cell_ids"] = cell_ids - return data - - @singledispatchmethod - def convert(self, data: DataTree | DataArray, classes: set[int] | None = None) -> DataTree | DataArray: - """Convert data to include cell IDs. - - Args: - data: Input data to convert - classes: Optional set of class IDs to use - - Returns: - Converted data with cell IDs - - Raises: - ValueError: If data type is not supported - """ - raise ValueError(f"Unsupported data type: {type(data)}. " "Please use .convert() from Labels2DModel instead.") - - @convert.register(DataArray) - def _(self, data: DataArray, classes: set[int] | None = None) -> DataArray: - """Convert DataArray to include cell IDs. - - Args: - data: Input DataArray - classes: Optional set of class IDs to use - - Returns: - DataArray with cell IDs in attributes - """ - if classes is not None: - data.attrs["cell_ids"] = classes - else: - data = self._get_cell_ids(data) - return data - - @convert.register(DataTree) - def _(self, data: DataTree, classes: set[int] | None = None) -> DataTree: - """Convert DataTree to include cell IDs. - - Args: - data: Input DataTree - classes: Optional set of class IDs to use - - Returns: - DataTree with cell IDs in attributes - """ - if classes is not None: - for d in data: - data[d].attrs["cell_ids"] = classes - for d in data: - data[d] = self._get_cell_ids(data[d]) - return data From 4c482d43bfc5f9914e6177d93d9a980cf83a3524 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 14:41:43 +0100 Subject: [PATCH 02/14] remove all occurances of .attrs["cell_ids"] --- src/scportrait/pipeline/extraction.py | 8 ++------ src/scportrait/pipeline/project.py | 6 ++---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index 7a769114..eeada4f3 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -324,9 +324,7 @@ def _get_segmentation_info(self): # this mask will be used to calculate the cell centers if self.n_masks == 2: # perform sanity check that the masks have the same ids - assert ( - _sdata[self.nucleus_key].attrs["cell_ids"] == _sdata[self.cytosol_key].attrs["cell_ids"] - ), "Nucleus and cytosol masks contain different cell ids. Cannot proceed with extraction." + # THIS NEEDS TO BE IMPLEMENTED HERE self.main_segmenation_mask = self.nucleus_key @@ -370,9 +368,7 @@ def _get_centers(self): ), "Cell ids in centers are not unique. Cannot proceed with extraction." # double check that the cell_ids contained in the seg masks match to those from centers - assert set(self.centers_cell_ids) == set( - _sdata[self.main_segmenation_mask].attrs["cell_ids"] - ), "Cell ids from centers do not match those from the segmentation mask. Cannot proceed with extraction." + # THIS NEEDS TO BE IMPLEMENTED HERE def _get_classes_to_extract(self): if self.partial_processing: diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 818267d0..900c6e56 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -954,10 +954,8 @@ def load_input_from_sdata( # ensure that the provided nucleus and cytosol segmentations fullfill the scPortrait requirements # requirements are: # 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids - if self.nuc_seg_status and self.cyto_seg_status: - assert ( - self.sdata[self.nuc_seg_name].attrs["cell_ids"] == self.sdata[self.cyto_seg_name].attrs["cell_ids"] - ), "The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids." + # if self.nuc_seg_status and self.cyto_seg_status: + # THIS NEEDS TO BE IMPLEMENTED HERE # 2. the nucleus segmentation ids and the cytosol segmentation ids need to match # THIS NEEDS TO BE IMPLEMENTED HERE From 088bf42cb7b94851911dd39229df974662f6284c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:28:43 +0100 Subject: [PATCH 03/14] [FIX] ensure channel names are preserved when reading input images from spatialdata object --- src/scportrait/pipeline/project.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 900c6e56..ef2d5112 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -888,17 +888,22 @@ def load_input_from_sdata( if isinstance(image, xarray.DataTree): image_c, image_x, image_y = image.scale0.image.shape + # ensure chunking is correct for scale in image: self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) + + # get channel names + channel_names = image.scale0.image.c.values + elif isinstance(image, xarray.DataArray): - ( - image_c, - image_x, - image_y, - ) = image.shape + image_c, image_x, image_y = image.shape + + # ensure chunking is correct self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) + channel_names = image.c.values + # Reset all transformations if image.attrs.get("transform"): self.log("Image contained transformations which which were removed.") @@ -907,7 +912,7 @@ def load_input_from_sdata( # check coordinate system of input image ### PLACEHOLDER - self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME) + self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=channel_names) # check if a nucleus segmentation exists and if so add it to the sdata object if nucleus_segmentation_name is not None: From a74327d0db136c82358d954ba6c2d26dc3d2e09b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:36:46 +0100 Subject: [PATCH 04/14] [FEATURE] make rechunking optional when reading from an existing spatialdata object it should be possible to turn of rechunking when reading from an existing spatialdata object under the asumption that the chunksize was set logically. This prevents having to recalculate already existing chunks which has a lot of overhead during writing. --- src/scportrait/pipeline/project.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index ef2d5112..9f4b36c1 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -846,6 +846,7 @@ def load_input_from_sdata( overwrite: bool | None = None, keep_all: bool = True, remove_duplicates: bool = True, + rechunk: bool = False, ) -> None: """ Load input image from a spatialdata object. @@ -890,8 +891,9 @@ def load_input_from_sdata( image_c, image_x, image_y = image.scale0.image.shape # ensure chunking is correct - for scale in image: - self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) + if rechunk: + for scale in image: + self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) # get channel names channel_names = image.scale0.image.c.values @@ -900,7 +902,8 @@ def load_input_from_sdata( image_c, image_x, image_y = image.shape # ensure chunking is correct - self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) + if rechunk: + self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D) channel_names = image.c.values @@ -931,7 +934,9 @@ def load_input_from_sdata( mask_y == image_y ), "Nucleus segmentation mask does not match input image size." - self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct + if rechunk: + self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct + self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name) # check if a cytosol segmentation exists and if so add it to the sdata object @@ -951,7 +956,9 @@ def load_input_from_sdata( mask_y == image_y ), "Nucleus segmentation mask does not match input image size." - self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct + if rechunk: + self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct + self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name) self.get_project_status() From f13256ac742ae96b168785088ac9867c88e57fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:53:57 +0100 Subject: [PATCH 05/14] adapt centers file naming convention so that each segmentation mask can have its own centers file --- src/scportrait/pipeline/_base.py | 2 +- src/scportrait/pipeline/_utils/sdata_io.py | 3 ++- src/scportrait/pipeline/extraction.py | 7 ++++--- src/scportrait/pipeline/project.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/scportrait/pipeline/_base.py b/src/scportrait/pipeline/_base.py index d241cdb1..af6824a4 100644 --- a/src/scportrait/pipeline/_base.py +++ b/src/scportrait/pipeline/_base.py @@ -153,7 +153,7 @@ class ProcessingStep(Logable): DEFAULT_SEG_NAME_0 = "nucleus" DEFAULT_SEG_NAME_1 = "cytosol" - DEFAULT_CENTERS_NAME = "centers_cells" + DEFAULT_CENTERS_NAME = "centers" DEFAULT_CHUNK_SIZE = (1, 1000, 1000) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index e168e34c..479538e1 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -359,7 +359,8 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None """ _sdata = self._read_sdata() centroids_object = self._get_centers(_sdata, segmentation_label) - self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite) + centers_name = f"{self.centers_name}_{segmentation_label}" + self._write_points_object_sdata(centroids_object, centers_name, overwrite=overwrite) ## load elements from sdata to a memory mapped array def _load_input_image_to_memmap( diff --git a/src/scportrait/pipeline/extraction.py b/src/scportrait/pipeline/extraction.py index eeada4f3..40fef5b1 100644 --- a/src/scportrait/pipeline/extraction.py +++ b/src/scportrait/pipeline/extraction.py @@ -350,17 +350,18 @@ def _get_centers(self): _sdata = self.filehandler._read_sdata() # calculate centers if they have not been calculated yet - if self.DEFAULT_CENTERS_NAME not in _sdata: + centers_name = f"{self.DEFAULT_CENTERS_NAME}_{self.main_segmenation_mask}" + if centers_name not in _sdata: self.filehandler._add_centers(self.main_segmenation_mask, overwrite=self.overwrite) _sdata = self.filehandler._read_sdata() # reread to ensure we have updated version - centers = _sdata[self.DEFAULT_CENTERS_NAME].values.compute() + centers = _sdata[centers_name].values.compute() # round to int so that we can use them as indices centers = np.round(centers).astype(int) self.centers = centers - self.centers_cell_ids = _sdata[self.DEFAULT_CENTERS_NAME].index.values.compute() + self.centers_cell_ids = _sdata[centers_name].index.values.compute() # ensure that the centers ids are unique assert len(self.centers_cell_ids) == len( diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 9f4b36c1..8d4ee047 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -80,7 +80,7 @@ class Project(Logable): DEFAULT_SEG_NAME_0: str = "nucleus" DEFAULT_SEG_NAME_1: str = "cytosol" - DEFAULT_CENTERS_NAME: str = "centers_cells" + DEFAULT_CENTERS_NAME: str = "centers" DEFAULT_CHUNK_SIZE_3D: ChunkSize3D = (1, 1000, 1000) DEFAULT_CHUNK_SIZE_2D: ChunkSize2D = (1000, 1000) From 6299cc4410ebcbe4bf42dc26cfca38c5a3b64244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:00:40 +0100 Subject: [PATCH 06/14] [FIX] add support for calculating centers on multi-scaled segmentation masks --- src/scportrait/pipeline/_utils/sdata_io.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 479538e1..df2185aa 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -347,7 +347,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod if segmentation_label not in sdata.labels: raise ValueError(f"Segmentation {segmentation_label} not found in sdata object.") - centers = calculate_centroids(sdata.labels[segmentation_label]) + mask = sdata.labels[segmentation_label] + if isinstance(mask, xarray.DataTree): + mask = mask.scale0.image + centers = calculate_centroids(mask) return centers def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None: From 6d16ccfaab6566a78a1800af445453648514447c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:04:39 +0100 Subject: [PATCH 07/14] [FIX] add center calculation to reading segmentation masks from spatialdata --- src/scportrait/pipeline/project.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 8d4ee047..d361bc5b 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -938,6 +938,7 @@ def load_input_from_sdata( self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name) + self.filehandler._add_centers(segmentation_label=self.nuc_seg_name) # check if a cytosol segmentation exists and if so add it to the sdata object if cytosol_segmentation_name is not None: @@ -960,6 +961,7 @@ def load_input_from_sdata( self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name) + self.filehandler._add_centers(segmentation_label=self.cyto_seg_name) self.get_project_status() @@ -1001,12 +1003,6 @@ def load_input_from_sdata( else: self.log(f"No region annotation found for the nucleus segmentation {nucleus_segmentation_name}.") - # add centers of cells for available nucleus map - centroids = calculate_centroids(self.sdata.labels[region_name], coordinate_system="global") - self.filehandler._write_points_object_sdata(centroids, self.DEFAULT_CENTERS_NAME) - - self.centers_status = True - # add cytosol segmentations if available if self.cyto_seg_status: if cytosol_segmentation_name in region_annotation.keys(): From 9524a0d8eb3d11896e136f180052fc6db0a17c07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:07:51 +0100 Subject: [PATCH 08/14] add center calculation to segmentation workflow when the segmentation mask is written to the spatialdata object the centers file for that mask is also generated. This not only tracks unique cell_ids but also their location for easier visualization see #42 --- src/scportrait/pipeline/segmentation/segmentation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/scportrait/pipeline/segmentation/segmentation.py b/src/scportrait/pipeline/segmentation/segmentation.py index 21d9c422..5ef3495f 100644 --- a/src/scportrait/pipeline/segmentation/segmentation.py +++ b/src/scportrait/pipeline/segmentation/segmentation.py @@ -310,12 +310,14 @@ def _save_segmentation_sdata(self, labels, classes, masks=None): self.filehandler._write_segmentation_sdata( labels[ix], self.nuc_seg_name, classes=classes, overwrite=self.overwrite ) + self.filehandler._add_centers(self.nuc_seg_name, overwrite=self.overwrite) if "cytosol" in masks: ix = masks.index("cytosol") self.filehandler._write_segmentation_sdata( labels[ix], self.cyto_seg_name, classes=classes, overwrite=self.overwrite ) + self.filehandler._add_centers(self.cyto_seg_name, overwrite=self.overwrite) def save_map(self, map_name): """Saves newly computed map. From 963036a06951ca3a2501e16bee217025882e0474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:06:49 +0100 Subject: [PATCH 09/14] [FEATURE] add method to automatically add an anndata object to the spatial data object addresses #170 --- src/scportrait/pipeline/_utils/sdata_io.py | 45 ++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index df2185aa..b6b9f4ba 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -8,6 +8,7 @@ import numpy as np import xarray from alphabase.io import tempmmap +from anndata import AnnData from spatialdata import SpatialData from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations.transformations import Identity @@ -313,6 +314,50 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over self.log(f"Points {points_name} written to sdata object.") + def _write_table_sdata( + self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False + ) -> None: + """Write anndata to SpatialData. + + Args: + adata: AnnData object to write + table_name: Name for the table object under which it should be saved + segmentation_mask_name: Name of the segmentation mask that this table annotates + overwrite: Whether to overwrite existing data + + Returns: + None (writes to sdata object) + """ + _sdata = self._read_sdata() + + assert isinstance(adata, AnnData), "Input data must be an AnnData object." + assert segmentation_mask_name in _sdata.labels, "Segmentation mask not found in sdata object." + + # get obs and obs_indices + obs = adata.obs + obs_indices = adata.obs.index + + # sanity checking + assert len(obs_indices) == len(set(obs_indices)), "Instance IDs are not unique." + cell_ids_mask = set(_sdata[f"{self.centers_name}_{segmentation_mask_name}"].index.values.compute()) + assert ( + len(set(obs_indices).difference(cell_ids_mask)) == 0 + ), "Instance IDs do not match segmentation mask cell IDs." + + obs["instance_id"] = obs_indices.astype(int) + obs["region"] = "segmentation_mask_name" + obs["region"] = obs["region"].astype("category") + + adata.obs = obs + table = TableModel.parse( + adata, + region=[segmentation_mask_name], + region_key="region", + instance_key="instance_id", + ) + + self._write_table_object_sdata(table, table_name, overwrite=overwrite) + def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrite: bool = False) -> None: """Write table object to SpatialData. From 5eb3fa8e5950b5c9f6f40d22b5628ad53e6e24e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:14:11 +0100 Subject: [PATCH 10/14] [FIX] utilize filehandler.write_table_sdata method to reduce code overhead --- src/scportrait/pipeline/featurization.py | 64 ++++++------------------ 1 file changed, 14 insertions(+), 50 deletions(-) diff --git a/src/scportrait/pipeline/featurization.py b/src/scportrait/pipeline/featurization.py index 4a357619..296db9ad 100644 --- a/src/scportrait/pipeline/featurization.py +++ b/src/scportrait/pipeline/featurization.py @@ -1269,64 +1269,17 @@ def _write_results_sdata(self, results, mask_type="seg_all"): if self.project.nuc_seg_status: # save nucleus segmentation columns_drop = [x for x in results.columns if self.MASK_NAMES[1] in x] - - _results = results.drop(columns=columns_drop) - _results.set_index("cell_id", inplace=True) - _results.drop(columns=["label"], inplace=True) - - feature_matrix = _results.to_numpy() - var_names = _results.columns - obs_indices = _results.index.astype(str) - - obs = pd.DataFrame() - obs.index = obs_indices - obs["instance_id"] = obs_indices.astype(int) - obs["region"] = f"{mask_type}_{self.MASK_NAMES[0]}" - obs["region"] = obs["region"].astype("category") - - table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) - table = TableModel.parse( - table, - region=[f"{mask_type}_{self.MASK_NAMES[0]}"], - region_key="region", - instance_key="instance_id", - ) - - # define name to save table under - self.label.replace("CellFeaturizer_", "") # remove class name from label to ensure we dont have duplicates + segmentation_name = f"{mask_type}_{self.MASK_NAMES[0]}" if self.channel_selection is not None: table_name = f"{self.__class__.__name__ }_{self.config['channel_selection']}_{self.MASK_NAMES[0]}" else: table_name = f"{self.__class__.__name__ }_{self.MASK_NAMES[0]}" - self.filehandler._write_table_object_sdata(table, table_name, overwrite=self.overwrite_run_path) - if self.project.cyto_seg_status: # save cytosol segmentation columns_drop = [x for x in results.columns if self.MASK_NAMES[0] in x] - - _results = results.drop(columns=columns_drop) - _results.set_index("cell_id", inplace=True) - _results.drop(columns=["label"], inplace=True) - - feature_matrix = _results.to_numpy() - var_names = _results.columns - obs_indices = _results.index.astype(str) - - obs = pd.DataFrame() - obs.index = obs_indices - obs["instance_id"] = obs_indices.astype(int) - obs["region"] = f"{mask_type}_{self.MASK_NAMES[1]}" - obs["region"] = obs["region"].astype("category") - - table = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=obs) - table = TableModel.parse( - table, - region=[f"{mask_type}_{self.MASK_NAMES[1]}"], - region_key="region", - instance_key="instance_id", - ) + segmentation_name = f"{mask_type}_{self.MASK_NAMES[1]}" # define name to save table under if self.channel_selection is not None: @@ -1334,7 +1287,18 @@ def _write_results_sdata(self, results, mask_type="seg_all"): else: table_name = f"{self.__class__.__name__ }_{self.MASK_NAMES[1]}" - self.filehandler._write_table_object_sdata(table, table_name, overwrite=self.overwrite_run_path) + _results = results.drop(columns=columns_drop) + _results.set_index("cell_id", inplace=True) + _results.drop(columns=["label"], inplace=True) + + feature_matrix = _results.to_numpy() + var_names = _results.columns + obs_indices = _results.index.astype(str) + + adata = AnnData(X=feature_matrix, var=pd.DataFrame(index=var_names), obs=pd.DataFrame(index=obs_indices)) + self.filehandler._write_table_sdata( + adata, segmentation_mask_name=segmentation_name, table_name=table_name, overwrite=self.overwrite_run_path + ) class CellFeaturizer(_cellFeaturizerBase): From 95f013bc5829d3665257c44fb6e69c581514d848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:22:16 +0100 Subject: [PATCH 11/14] [FIX] bugs in write_table_sdata method --- src/scportrait/pipeline/_utils/sdata_io.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index b6b9f4ba..9ff8aaac 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -335,17 +335,15 @@ def _write_table_sdata( # get obs and obs_indices obs = adata.obs - obs_indices = adata.obs.index + obs_indices = adata.obs.index.astype(int) # need to ensure int subtype to be able to annotate seg masks # sanity checking assert len(obs_indices) == len(set(obs_indices)), "Instance IDs are not unique." cell_ids_mask = set(_sdata[f"{self.centers_name}_{segmentation_mask_name}"].index.values.compute()) - assert ( - len(set(obs_indices).difference(cell_ids_mask)) == 0 - ), "Instance IDs do not match segmentation mask cell IDs." + assert set(obs_indices).issubset(cell_ids_mask), "Instance IDs do not match segmentation mask cell IDs." - obs["instance_id"] = obs_indices.astype(int) - obs["region"] = "segmentation_mask_name" + obs["instance_id"] = obs_indices + obs["region"] = segmentation_mask_name obs["region"] = obs["region"].astype("category") adata.obs = obs From e2563544d74dfad740b2141a0cc20ede9cf36866 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:22:34 +0100 Subject: [PATCH 12/14] ignore errors from shutil.rmtree call --- src/scportrait/pipeline/project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index d361bc5b..a82e17a1 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -876,7 +876,7 @@ def load_input_from_sdata( # read input sdata object sdata_input = SpatialData.read(sdata_path) if keep_all: - shutil.rmtree(self.sdata_path) # remove old sdata object + shutil.rmtree(self.sdata_path, ignore_errors=True) # remove old sdata object sdata_input.write(self.sdata_path, overwrite=True) del sdata_input sdata_input = self.filehandler.get_sdata() From 99db0d3b1d98da3b1bb711c47e521603865b37bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:22:51 +0100 Subject: [PATCH 13/14] [LOG] add center calculation to log file --- src/scportrait/pipeline/project.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index a82e17a1..357a8500 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -915,6 +915,11 @@ def load_input_from_sdata( # check coordinate system of input image ### PLACEHOLDER + # check channel names + self.log( + f"Found the following channel names in the input image and saving in the spatialdata object: {channel_names}" + ) + self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=channel_names) # check if a nucleus segmentation exists and if so add it to the sdata object @@ -938,6 +943,9 @@ def load_input_from_sdata( self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name) + self.log( + f"Calculating centers for nucleus segmentation mask {self.nuc_seg_name} and adding to spatialdata object." + ) self.filehandler._add_centers(segmentation_label=self.nuc_seg_name) # check if a cytosol segmentation exists and if so add it to the sdata object @@ -961,6 +969,9 @@ def load_input_from_sdata( self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name) + self.log( + f"Calculating centers for cytosol segmentation mask {self.nuc_seg_name} and adding to spatialdata object." + ) self.filehandler._add_centers(segmentation_label=self.cyto_seg_name) self.get_project_status() From 5e765ac81f43cc41743d73456dfa60e1c4ee1edc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 19:19:49 +0100 Subject: [PATCH 14/14] [FIX] addresses #169, also adds short-term fix for #150 --- src/scportrait/pipeline/_utils/sdata_io.py | 76 +++++++++++----------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 9ff8aaac..6e757251 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -187,49 +187,47 @@ def _write_image_sdata( # check if the image is already a multi-scale image if isinstance(image, xarray.DataTree): # if so only validate the model since this means we are getting the image from a spatialdata object already - # image = Image2DModel.validate(image) - # this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865 + # fix until #https://github.com/scverse/spatialdata/issues/528 is resolved + Image2DModel().validate(image) if scale_factors is not None: Warning("Scale factors are ignored when passing a multi-scale image.") - image = image.scale0.image - - if scale_factors is None: - scale_factors = [2, 4, 8] - if scale_factors is None: - scale_factors = [2, 4, 8] - - if isinstance(image, xarray.DataArray): - # if so first validate the model since this means we are getting the image from a spatialdata object already - # then apply the scales transform - # image = Image2DModel.validate(image) - # this appraoch is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865 - - if channel_names is not None: - Warning( - "Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray." + else: + if scale_factors is None: + scale_factors = [2, 4, 8] + if scale_factors is None: + scale_factors = [2, 4, 8] + + if isinstance(image, xarray.DataArray): + # if so first validate the model since this means we are getting the image from a spatialdata object already + # fix until #https://github.com/scverse/spatialdata/issues/528 is resolved + Image2DModel().validate(image) + + if channel_names is not None: + Warning( + "Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray." + ) + + image = Image2DModel.parse( + image, + scale_factors=scale_factors, + rgb=False, ) - image = Image2DModel.parse( - image, - scale_factors=scale_factors, - rgb=False, - ) - - else: - if channel_names is None: - channel_names = [f"channel_{i}" for i in range(image.shape[0])] - - # transform to spatialdata image model - transform_original = Identity() - image = Image2DModel.parse( - image, - dims=["c", "y", "x"], - chunks=chunks, - c_coords=channel_names, - scale_factors=scale_factors, - transformations={"global": transform_original}, - rgb=False, - ) + else: + if channel_names is None: + channel_names = [f"channel_{i}" for i in range(image.shape[0])] + + # transform to spatialdata image model + transform_original = Identity() + image = Image2DModel.parse( + image, + dims=["c", "y", "x"], + chunks=chunks, + c_coords=channel_names, + scale_factors=scale_factors, + transformations={"global": transform_original}, + rgb=False, + ) if overwrite: self._force_delete_object(_sdata, image_name, "images")