improve spatialdata file handling #171

Merged · 14 commits · Feb 11, 2025
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/_base.py
@@ -153,7 +153,7 @@ class ProcessingStep(Logable):
DEFAULT_SEG_NAME_0 = "nucleus"
DEFAULT_SEG_NAME_1 = "cytosol"

DEFAULT_CENTERS_NAME = "centers_cells"
DEFAULT_CENTERS_NAME = "centers"

DEFAULT_CHUNK_SIZE = (1, 1000, 1000)

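The renamed default ties into the sdata_io.py changes below: centers are now written per segmentation mask rather than under a single global key. A minimal sketch of the resulting naming scheme (the helper function is hypothetical; the names come from this diff):

# Illustration only: one centers element per segmentation mask instead of a
# single "centers_cells" element.
DEFAULT_CENTERS_NAME = "centers"

def centers_key(segmentation_mask_name: str) -> str:
    # e.g. "centers_seg_all_nucleus" for the nucleus segmentation mask
    return f"{DEFAULT_CENTERS_NAME}_{segmentation_mask_name}"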
144 changes: 89 additions & 55 deletions src/scportrait/pipeline/_utils/sdata_io.py
@@ -8,12 +8,12 @@
import numpy as np
import xarray
from alphabase.io import tempmmap
from anndata import AnnData
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, PointsModel, TableModel
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel
from spatialdata.transformations.transformations import Identity

from scportrait.pipeline._base import Logable
from scportrait.pipeline._utils.spatialdata_classes import spLabels2DModel
from scportrait.pipeline._utils.spatialdata_helper import (
calculate_centroids,
get_chunk_size,
@@ -87,13 +87,6 @@ def _read_sdata(self) -> SpatialData:
_sdata = self._create_empty_sdata()
_sdata.write(self.sdata_path, overwrite=True)

allowed_labels = ["seg_all_nucleus", "seg_all_cytosol"]
for key in _sdata.labels:
if key in allowed_labels:
segmentation_object = _sdata.labels[key]
if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=None)

return _sdata

def get_sdata(self) -> SpatialData:
@@ -194,49 +187,47 @@ def _write_image_sdata(
# check if the image is already a multi-scale image
if isinstance(image, xarray.DataTree):
# if so only validate the model since this means we are getting the image from a spatialdata object already
# image = Image2DModel.validate(image)
# this approach is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865
# fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
Image2DModel().validate(image)
if scale_factors is not None:
Warning("Scale factors are ignored when passing a multi-scale image.")
image = image.scale0.image

if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

if isinstance(image, xarray.DataArray):
# if so first validate the model since this means we are getting the image from a spatialdata object already
# then apply the scales transform
# image = Image2DModel.validate(image)
# this approach is currently not functional but an issue was opened at https://github.com/scverse/spatialdata/issues/865

if channel_names is not None:
Warning(
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
else:
if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

if isinstance(image, xarray.DataArray):
# if so first validate the model since this means we are getting the image from a spatialdata object already
# fix until #https://github.com/scverse/spatialdata/issues/528 is resolved
Image2DModel().validate(image)

if channel_names is not None:
Warning(
"Channel names are ignored when passing a single scale image in the DataArray format. Channel names are read directly from the DataArray."
)

image = Image2DModel.parse(
image,
scale_factors=scale_factors,
rgb=False,
)

image = Image2DModel.parse(
image,
scale_factors=scale_factors,
rgb=False,
)

else:
if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)
else:
if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)

if overwrite:
self._force_delete_object(_sdata, image_name, "images")
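Condensed for readability, the reworked input handling in _write_image_sdata distinguishes three cases. The sketch below is a simplification of the diff above, not the verbatim implementation, and the standalone function name is hypothetical:

import xarray
from spatialdata.models import Image2DModel
from spatialdata.transformations.transformations import Identity

def _parse_image_sketch(image, channel_names=None, scale_factors=None, chunks=None):
    # Case 1: multi-scale image (DataTree) taken from an existing SpatialData
    # object -- validate it, then work with the full-resolution scale0 layer.
    if isinstance(image, xarray.DataTree):
        Image2DModel().validate(image)
        image = image.scale0.image

    if scale_factors is None:
        scale_factors = [2, 4, 8]

    # Case 2: single-scale DataArray from a SpatialData object -- channel names
    # are read from the DataArray itself, so the channel_names argument is ignored.
    if isinstance(image, xarray.DataArray):
        Image2DModel().validate(image)
        return Image2DModel.parse(image, scale_factors=scale_factors, rgb=False)

    # Case 3: plain array input -- assign channel names and an identity transform.
    if channel_names is None:
        channel_names = [f"channel_{i}" for i in range(image.shape[0])]
    return Image2DModel.parse(
        image,
        dims=["c", "y", "x"],
        chunks=chunks,
        c_coords=channel_names,
        scale_factors=scale_factors,
        transformations={"global": Identity()},
        rgb=False,
    )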
@@ -249,7 +240,7 @@ def _write_image_sdata(

def _write_segmentation_object_sdata(
self,
segmentation_object: spLabels2DModel,
segmentation_object: Labels2DModel,
segmentation_label: str,
classes: set[str] | None = None,
overwrite: bool = False,
@@ -264,9 +255,6 @@ def _write_segmentation_object_sdata(
"""
_sdata = self._read_sdata()

if not hasattr(segmentation_object.attrs, "cell_ids"):
segmentation_object = spLabels2DModel().convert(segmentation_object, classes=classes)

if overwrite:
self._force_delete_object(_sdata, segmentation_label, "labels")

@@ -294,7 +282,7 @@ def _write_segmentation_sdata(
overwrite: Whether to overwrite existing data
"""
transform_original = Identity()
mask = spLabels2DModel.parse(
mask = Labels2DModel.parse(
segmentation,
dims=["y", "x"],
transformations={"global": transform_original},
@@ -324,6 +312,48 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over

self.log(f"Points {points_name} written to sdata object.")

def _write_table_sdata(
self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False
) -> None:
"""Write anndata to SpatialData.

Args:
adata: AnnData object to write
table_name: Name for the table object under which it should be saved
segmentation_mask_name: Name of the segmentation mask that this table annotates
overwrite: Whether to overwrite existing data

Returns:
None (writes to sdata object)
"""
_sdata = self._read_sdata()

assert isinstance(adata, AnnData), "Input data must be an AnnData object."
assert segmentation_mask_name in _sdata.labels, "Segmentation mask not found in sdata object."

# get obs and obs_indices
obs = adata.obs
obs_indices = adata.obs.index.astype(int) # need to ensure int subtype to be able to annotate seg masks

# sanity checking
assert len(obs_indices) == len(set(obs_indices)), "Instance IDs are not unique."
cell_ids_mask = set(_sdata[f"{self.centers_name}_{segmentation_mask_name}"].index.values.compute())
assert set(obs_indices).issubset(cell_ids_mask), "Instance IDs do not match segmentation mask cell IDs."

obs["instance_id"] = obs_indices
obs["region"] = segmentation_mask_name
obs["region"] = obs["region"].astype("category")

adata.obs = obs
table = TableModel.parse(
adata,
region=[segmentation_mask_name],
region_key="region",
instance_key="instance_id",
)

self._write_table_object_sdata(table, table_name, overwrite=overwrite)
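A usage sketch for the new _write_table_sdata helper; the filehandler instance, table name, and mask name below are assumptions for illustration:

import numpy as np
from anndata import AnnData

# Hypothetical feature table for three cells; obs_names must be castable to int
# and must match cell IDs present in the annotated segmentation mask.
adata = AnnData(X=np.random.rand(3, 10))
adata.obs_names = ["1", "2", "3"]

# `filehandler` is assumed to be the project's sdata filehandler instance.
filehandler._write_table_sdata(
    adata,
    table_name="cell_features",
    segmentation_mask_name="seg_all_nucleus",
    overwrite=True,
)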

def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrite: bool = False) -> None:
"""Write table object to SpatialData.

@@ -358,7 +388,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod
if segmentation_label not in sdata.labels:
raise ValueError(f"Segmentation {segmentation_label} not found in sdata object.")

centers = calculate_centroids(sdata.labels[segmentation_label])
mask = sdata.labels[segmentation_label]
if isinstance(mask, xarray.DataTree):
mask = mask.scale0.image
centers = calculate_centroids(mask)
return centers
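The centers element built from these centroids is a dask-backed points object indexed by cell ID, which is what the ID checks in _write_table_sdata above rely on. A short inspection sketch (the element name follows this PR's "<centers_name>_<segmentation_label>" convention and is an assumption for a concrete project):

# Read back the centers written for the nucleus segmentation mask.
centers = sdata["centers_seg_all_nucleus"]
cell_ids = centers.index.values.compute()  # integer cell IDs of the mask
coords = centers.values.compute()          # centroid coordinates per cell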

def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None:
Expand All @@ -370,7 +403,8 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
"""
_sdata = self._read_sdata()
centroids_object = self._get_centers(_sdata, segmentation_label)
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)
centers_name = f"{self.centers_name}_{segmentation_label}"
self._write_points_object_sdata(centroids_object, centers_name, overwrite=overwrite)

## load elements from sdata to a memory mapped array
def _load_input_image_to_memmap(
108 changes: 0 additions & 108 deletions src/scportrait/pipeline/_utils/spatialdata_classes.py

This file was deleted.

15 changes: 6 additions & 9 deletions src/scportrait/pipeline/extraction.py
@@ -324,9 +324,7 @@ def _get_segmentation_info(self):
# this mask will be used to calculate the cell centers
if self.n_masks == 2:
# perform sanity check that the masks have the same ids
assert (
_sdata[self.nucleus_key].attrs["cell_ids"] == _sdata[self.cytosol_key].attrs["cell_ids"]
), "Nucleus and cytosol masks contain different cell ids. Cannot proceed with extraction."
# THIS NEEDS TO BE IMPLEMENTED HERE

self.main_segmenation_mask = self.nucleus_key

@@ -352,27 +350,26 @@ def _get_centers(self):
_sdata = self.filehandler._read_sdata()

# calculate centers if they have not been calculated yet
if self.DEFAULT_CENTERS_NAME not in _sdata:
centers_name = f"{self.DEFAULT_CENTERS_NAME}_{self.main_segmenation_mask}"
if centers_name not in _sdata:
self.filehandler._add_centers(self.main_segmenation_mask, overwrite=self.overwrite)
_sdata = self.filehandler._read_sdata() # reread to ensure we have updated version

centers = _sdata[self.DEFAULT_CENTERS_NAME].values.compute()
centers = _sdata[centers_name].values.compute()

# round to int so that we can use them as indices
centers = np.round(centers).astype(int)

self.centers = centers
self.centers_cell_ids = _sdata[self.DEFAULT_CENTERS_NAME].index.values.compute()
self.centers_cell_ids = _sdata[centers_name].index.values.compute()

# ensure that the centers ids are unique
assert len(self.centers_cell_ids) == len(
set(self.centers_cell_ids)
), "Cell ids in centers are not unique. Cannot proceed with extraction."

# double check that the cell_ids contained in the seg masks match to those from centers
assert set(self.centers_cell_ids) == set(
_sdata[self.main_segmenation_mask].attrs["cell_ids"]
), "Cell ids from centers do not match those from the segmentation mask. Cannot proceed with extraction."
# THIS NEEDS TO BE IMPLEMENTED HERE

def _get_classes_to_extract(self):
if self.partial_processing: