Skip to content

Commit

Permalink
simplify _from_sdata method to rely on methods implemented in spatialdata
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Feb 11, 2025
1 parent 907bf19 commit 664353b
Showing 1 changed file with 32 additions and 164 deletions.
196 changes: 32 additions & 164 deletions src/scportrait/pipeline/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
rechunk_image,
remap_region_annotation_table,
)
from scportrait.spdata.write._helper import _get_shape, _make_key_lookup

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -875,186 +876,53 @@ def load_input_from_sdata(

# read input sdata object
sdata_input = SpatialData.read(sdata_path)
if keep_all:
shutil.rmtree(self.sdata_path, ignore_errors=True) # remove old sdata object
sdata_input.write(self.sdata_path, overwrite=True)
del sdata_input
sdata_input = self.filehandler.get_sdata()

self.get_project_status()

# get input image and write it to the final sdata object
image = sdata_input.images[input_image_name]
self.log(f"Adding image {input_image_name} to sdata object as 'input_image'.")

if isinstance(image, xarray.DataTree):
image_c, image_x, image_y = image.scale0.image.shape

# ensure chunking is correct
if rechunk:
for scale in image:
self._check_chunk_size(image[scale].image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D)
all_elements = [x.split("/")[1] for x in sdata_input.elements_paths_in_memory()]

# get channel names
channel_names = image.scale0.image.c.values
dict_elems = {self.DEFAULT_INPUT_IMAGE_NAME: sdata_input[input_image_name]}

elif isinstance(image, xarray.DataArray):
image_c, image_x, image_y = image.shape

# ensure chunking is correct
if rechunk:
self._check_chunk_size(image, chunk_size=self.DEFAULT_CHUNK_SIZE_3D)

channel_names = image.c.values

# Reset all transformations
if image.attrs.get("transform"):
self.log("Image contained transformations which which were removed.")
image.attrs["transform"] = None

# check coordinate system of input image
### PLACEHOLDER

# check channel names
self.log(
f"Found the following channel names in the input image and saving in the spatialdata object: {channel_names}"
)

self.filehandler._write_image_sdata(image, self.DEFAULT_INPUT_IMAGE_NAME, channel_names=channel_names)

# check if a nucleus segmentation exists and if so add it to the sdata object
if nucleus_segmentation_name is not None:
mask = sdata_input.labels[nucleus_segmentation_name]
self.log(
f"Adding nucleus segmentation mask '{nucleus_segmentation_name}' to sdata object as '{self.nuc_seg_name}'."
)

# if mask is multi-scale ensure we only use the scale 0
if isinstance(mask, xarray.DataTree):
mask = mask["scale0"].image

# ensure that loaded masks are at the same scale as the input image
mask_x, mask_y = mask.shape
assert (mask_x == image_x) and (
mask_y == image_y
), "Nucleus segmentation mask does not match input image size."

if rechunk:
self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct

self.filehandler._write_segmentation_object_sdata(mask, self.nuc_seg_name)
self.log(
f"Calculating centers for nucleus segmentation mask {self.nuc_seg_name} and adding to spatialdata object."
)
self.filehandler._add_centers(segmentation_label=self.nuc_seg_name)
dict_elems[self.nuc_seg_name] = sdata_input[nucleus_segmentation_name]
if remove_duplicates:
all_elements.remove(nucleus_segmentation_name)

# check if a cytosol segmentation exists and if so add it to the sdata object
if cytosol_segmentation_name is not None:
mask = sdata_input.labels[cytosol_segmentation_name]
self.log(
f"Adding cytosol segmentation mask '{cytosol_segmentation_name}' to sdata object as '{self.cyto_seg_name}'."
)

# if mask is multi-scale ensure we only use the scale 0
if isinstance(mask, xarray.DataTree):
mask = mask["scale0"].image
dict_elems[self.cyto_seg_name] = sdata_input[cytosol_segmentation_name]
if remove_duplicates:
all_elements.remove(cytosol_segmentation_name)

# ensure that loaded masks are at the same scale as the input image
mask_x, mask_y = mask.shape
assert (mask_x == image_x) and (
mask_y == image_y
), "Nucleus segmentation mask does not match input image size."
if keep_all:
shutil.rmtree(self.sdata_path, ignore_errors=True)
for elem in all_elements:
dict_elems[elem] = sdata_input[elem]

if rechunk:
self._check_chunk_size(mask, chunk_size=self.DEFAULT_CHUNK_SIZE_2D) # ensure chunking is correct

self.filehandler._write_segmentation_object_sdata(mask, self.cyto_seg_name)
self.log(
f"Calculating centers for cytosol segmentation mask {self.nuc_seg_name} and adding to spatialdata object."
)
self.filehandler._add_centers(segmentation_label=self.cyto_seg_name)
sdata = SpatialData.from_elements_dict(dict_elems)
sdata.write(self.sdata_path, overwrite=True)

# update project status
self.get_project_status()
_, x, y = _get_shape(sdata[self.DEFAULT_INPUT_IMAGE_NAME])

# ensure that the provided nucleus and cytosol segmentations fulfill the scPortrait requirements
# requirements are:
# 1. The nucleus segmentation mask and the cytosol segmentation mask must contain the same ids
# if self.nuc_seg_status and self.cyto_seg_status:
# THIS NEEDS TO BE IMPLEMENTED HERE

# 2. the nucleus segmentation ids and the cytosol segmentation ids need to match
# THIS NEEDS TO BE IMPLEMENTED HERE

# check if there are any annotations that match the nucleus/cytosol segmentations
if self.nuc_seg_status or self.cyto_seg_status:
region_annotation = generate_region_annotation_lookuptable(self.sdata)

if self.nuc_seg_status:
region_name = self.nuc_seg_name

# add existing nucleus annotations if available
if nucleus_segmentation_name in region_annotation.keys():
for x in region_annotation[nucleus_segmentation_name]:
table_name, table = x

new_table_name = f"annot_{region_name}_{table_name}"
self.overwrite = original_overwrite

table = remap_region_annotation_table(table, region_name=region_name)
if self.nuc_seg_status:
# check input size
_, x_mask, y_mask = _get_shape(sdata[self.nuc_seg_name])
assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size."

self.filehandler._write_table_object_sdata(table, new_table_name)
self.log(
f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}."
)

if keep_all and remove_duplicates:
self.log(
f"Deleting original annotation {table_name} for nucleus segmentation {nucleus_segmentation_name} from sdata object to prevent information duplication."
)
self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables")
else:
self.log(f"No region annotation found for the nucleus segmentation {nucleus_segmentation_name}.")

# add cytosol segmentations if available
if self.cyto_seg_status:
if cytosol_segmentation_name in region_annotation.keys():
for x in region_annotation[cytosol_segmentation_name]:
table_name, table = x
region_name = self.cyto_seg_name
new_table_name = f"annot_{region_name}_{table_name}"

table = remap_region_annotation_table(table, region_name=region_name)
self.filehandler._write_table_object_sdata(table, new_table_name)

self.log(
f"Added annotation {new_table_name} to spatialdata object for segmentation object {region_name}."
)

if keep_all and remove_duplicates:
self.log(
f"Deleting original annotation {table_name} for cytosol segmentation {cytosol_segmentation_name} from sdata object to prevent information duplication."
)
self.filehandler._force_delete_object(self.sdata, name=table_name, type="tables")
else:
self.log(f"No region annotation found for the cytosol segmentation {cytosol_segmentation_name}.")
self.filehandler._add_centers(segmentation_label=self.nuc_seg_name)
if self.cyto_seg_status:
# check input size
_, x_mask, y_mask = _get_shape(sdata[self.cyto_seg_name])
assert x == x_mask and y == y_mask, "Input image and nucleus segmentation mask do not match in size."

if keep_all and remove_duplicates:
# remove input image
self.log(f"Deleting input image '{input_image_name}' from sdata object to prevent information duplication.")
self.filehandler._force_delete_object(self.sdata, name=input_image_name, type="images")
self.filehandler._add_centers(segmentation_label=self.cyto_seg_name)

if self.nuc_seg_status:
self.log(
f"Deleting original nucleus segmentation mask '{nucleus_segmentation_name}' from sdata object to prevent information duplication."
)
self.filehandler._force_delete_object(self.sdata, name=nucleus_segmentation_name, type="labels")
if self.cyto_seg_status:
self.log(
f"Deleting original cytosol segmentation mask '{cytosol_segmentation_name}' from sdata object to prevent information duplication."
)
self.filehandler._force_delete_object(self.sdata, name=cytosol_segmentation_name, type="labels")
if self.nuc_seg_status and self.cyto_seg_status:
ids_nuc = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.nuc_seg_name}"].index.values)
ids_cyto = set(sdata[f"{self.DEFAULT_CENTERS_NAME}_{self.cyto_seg_name}"].index.values)
assert ids_nuc == ids_cyto, "Nucleus and cytosol segmentation masks do not match."

self.get_project_status()
self.overwrite = original_overwrite # reset to original value

#### Functions to perform processing ####

Expand Down

0 comments on commit 664353b

Please sign in to comment.