Skip to content

Commit

Permalink
Merge pull request #171 from MannLabs/deprecate_cell_id_tracking
Browse files Browse the repository at this point in the history
improve spatialdata file handling
  • Loading branch information
sophiamaedler authored Feb 11, 2025
2 parents aac3456 + 5e765ac commit b717064
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 246 deletions.
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class ProcessingStep(Logable):
DEFAULT_SEG_NAME_0 = "nucleus"
DEFAULT_SEG_NAME_1 = "cytosol"

DEFAULT_CENTERS_NAME = "centers_cells"
DEFAULT_CENTERS_NAME = "centers"

DEFAULT_CHUNK_SIZE = (1, 1000, 1000)

Expand Down
144 changes: 89 additions & 55 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import numpy as np
import xarray
from alphabase.io import tempmmap
from anndata import AnnData
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel, TableModel
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel
from spatialdata.transformations.transformations import Identity

from scportrait.pipeline._base import Logable
from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel
from scportrait.pipeline._utils.spatialdata_helper import (
calculate_centroids,
get_chunk_size,
Expand Down Expand Up @@ -87,13 +87,6 @@ def _read_sdata(self) -> SpatialData:
_sdata = self._create_empty_sdata()
_sdata.write(self.sdata_path, overwrite=True)

allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
for key in _sdata.labels:
if key in allowed_labels:
segmentation_object = _sdata.labels[key]
if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)

return _sdata

def get_sdata(self) -> SpatialData:
Expand Down Expand Up @@ -194,49 +187,47 @@ def _write_image_sdata(
# check if the image is already a multi-scale image
if isinstance(image, xarray.DataTree):
# if so only validate the model since this means we are getting the image from a spatialdata object already
# image = Image2DModel.validate(image)
# this approach is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
# workaround until https://github.com/scverse/spatialdata/issues/528 is resolved
Image2DModel().validate(image)
if scale_factors is not None:
Warning("Scale factors are ignored when passing a multi-scale image.")
image = image.scale0.image

if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

if isinstance(image, xarray.DataArray):
# if so first validate the model since this means we are getting the image from a spatialdata object already
# then apply the scales transform
# image = Image2DModel.validate(image)
# this approach is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865

if channel_names is not None:
Warning(
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
else:
if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

if isinstance(image, xarray.DataArray):
# if so first validate the model since this means we are getting the image from a spatialdata object already
# workaround until https://github.com/scverse/spatialdata/issues/528 is resolved
Image2DModel().validate(image)

if channel_names is not None:
Warning(
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
)

image = Image2DModel.parse(
image,
scale_factors=scale_factors,
rgb=False,
)

image = Image2DModel.parse(
image,
scale_factors=scale_factors,
rgb=False,
)

else:
if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)
else:
if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)

if overwrite:
self._force_delete_object(_sdata, image_name, "images")
Expand All @@ -249,7 +240,7 @@ def _write_image_sdata(

def _write_segmentation_object_sdata(
self,
segmentation_object: spLabels2DModel,
segmentation_object: Labels2DModel,
segmentation_label: str,
classes: set[str] | None = None,
overwrite: bool = False,
Expand All @@ -264,9 +255,6 @@ def _write_segmentation_object_sdata(
"""
_sdata = self._read_sdata()

if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=classes)

if overwrite:
self._force_delete_object(_sdata, segmentation_label, "labels")

Expand Down Expand Up @@ -294,7 +282,7 @@ def _write_segmentation_sdata(
overwrite: Whether to overwrite existing data
"""
transform_original = Identity()
mask = spLabels2DModel.parse(
mask = Labels2DModel.parse(
segmentation,
dims=["y", "x"],
transformations={"global": transform_original},
Expand Down Expand Up @@ -324,6 +312,48 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over

self.log(f"Points {points_name} written to sdata object.")

def _write_table_sdata(
self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False
) -> None:
"""Write anndata to SpatialData.
Args:
adata: AnnData object to write
table_name: Name for the table object under which it should be saved
segmentation_mask_name: Name of the segmentation mask that this table annotates
overwrite: Whether to overwrite existing data
Returns:
None (writes to sdata object)
"""
_sdata = self._read_sdata()

assert isinstance(adata, AnnData), "Input data must be an AnnData object."
assert segmentation_mask_name in _sdata.labels, "Segmentation mask not found in sdata object."

# get obs and obs_indices
obs = adata.obs
obs_indices = adata.obs.index.astype(int) # need to ensure int subtype to be able to annotate seg masks

# sanity checking
assert len(obs_indices) == len(set(obs_indices)), "Instance IDs are not unique."
cell_ids_mask = set(_sdata[f"{self.centers_name}_{segmentation_mask_name}"].index.values.compute())
assert set(obs_indices).issubset(cell_ids_mask), "Instance IDs do not match segmentation mask cell IDs."

obs["instance_id"] = obs_indices
obs["region"] = segmentation_mask_name
obs["region"] = obs["region"].astype("category")

adata.obs = obs
table = TableModel.parse(
adata,
region=[segmentation_mask_name],
region_key="region",
instance_key="instance_id",
)

self._write_table_object_sdata(table, table_name, overwrite=overwrite)

def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrite: bool = False) -> None:
"""Write table object to SpatialData.
Expand Down Expand Up @@ -358,7 +388,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod
if segmentation_label not in sdata.labels:
raise ValueError(f"Segmentation {segmentation_label} not found in sdata object.")

centers = calculate_centroids(sdata.labels[segmentation_label])
mask = sdata.labels[segmentation_label]
if isinstance(mask, xarray.DataTree):
mask = mask.scale0.image
centers = calculate_centroids(mask)
return centers

def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None:
Expand All @@ -370,7 +403,8 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
"""
_sdata = self._read_sdata()
centroids_object = self._get_centers(_sdata, segmentation_label)
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)
centers_name = f"{self.centers_name}_{segmentation_label}"
self._write_points_object_sdata(centroids_object, centers_name, overwrite=overwrite)

## load elements from sdata to a memory mapped array
def _load_input_image_to_memmap(
Expand Down
108 changes: 0 additions & 108 deletions src/scportrait/pipeline/_utils/spatialdata_classes.py

This file was deleted.

15 changes: 6 additions & 9 deletions src/scportrait/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,7 @@ def _get_segmentation_info(self):
# this mask will be used to calculate the cell centers
if self.n_masks == 2:
# perform sanity check that the masks have the same ids
assert (
_sdata[self.nucleus_key].attrs["cell_ids"] == _sdata[self.cytosol_key].attrs["cell_ids"]
), "Nucleus and cytosol masks contain different cell ids. Cannot proceed with extraction."
# TODO: this sanity check still needs to be implemented here

self.main_segmenation_mask = self.nucleus_key

Expand All @@ -352,27 +350,26 @@ def _get_centers(self):
_sdata = self.filehandler._read_sdata()

# calculate centers if they have not been calculated yet
if self.DEFAULT_CENTERS_NAME not in _sdata:
centers_name = f"{self.DEFAULT_CENTERS_NAME}_{self.main_segmenation_mask}"
if centers_name not in _sdata:
self.filehandler._add_centers(self.main_segmenation_mask, overwrite=self.overwrite)
_sdata = self.filehandler._read_sdata() # reread to ensure we have updated version

centers = _sdata[self.DEFAULT_CENTERS_NAME].values.compute()
centers = _sdata[centers_name].values.compute()

# round to int so that we can use them as indices
centers = np.round(centers).astype(int)

self.centers = centers
self.centers_cell_ids = _sdata[self.DEFAULT_CENTERS_NAME].index.values.compute()
self.centers_cell_ids = _sdata[centers_name].index.values.compute()

# ensure that the centers ids are unique
assert len(self.centers_cell_ids) == len(
set(self.centers_cell_ids)
), "Cell ids in centers are not unique. Cannot proceed with extraction."

# double check that the cell_ids contained in the seg masks match to those from centers
assert set(self.centers_cell_ids) == set(
_sdata[self.main_segmenation_mask].attrs["cell_ids"]
), "Cell ids from centers do not match those from the segmentation mask. Cannot proceed with extraction."
# TODO: this sanity check still needs to be implemented here

def _get_classes_to_extract(self):
if self.partial_processing:
Expand Down
Loading

0 comments on commit b717064

Please sign in to comment.