Skip to content

Commit

Permalink
deprecate _force_delete_object in favour of new spdata function
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Feb 11, 2025
1 parent ddf0cde commit 206c5d2
Showing 1 changed file with 11 additions and 58 deletions.
69 changes: 11 additions & 58 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
calculate_centroids,
get_chunk_size,
)
from scportrait.spdata.write._helper import add_element_sdata

ChunkSize2D: TypeAlias = tuple[int, int]
ChunkSize3D: TypeAlias = tuple[int, int, int]
Expand Down Expand Up @@ -97,21 +98,6 @@ def get_sdata(self) -> SpatialData:
"""
return self._read_sdata()

def _force_delete_object(self, sdata: SpatialData, name: str, type: ObjectType) -> None:
"""Force delete an object from the SpatialData object and directory.
Args:
sdata: SpatialData object
name: Name of object to delete
type: Type of object ("images", "labels", "points", "tables")
"""
if name in sdata:
del sdata[name]

path = os.path.join(self.sdata_path, type, name)
if os.path.exists(path):
shutil.rmtree(path, ignore_errors=True)

def _check_sdata_status(self, return_sdata: bool = False) -> SpatialData | None:
"""Check status of SpatialData objects.
Expand Down Expand Up @@ -229,12 +215,7 @@ def _write_image_sdata(
rgb=False,
)

if overwrite:
self._force_delete_object(_sdata, image_name, "images")

_sdata.images[image_name] = image
_sdata.write_element(image_name, overwrite=True)

add_element_sdata(_sdata, image, image_name, overwrite=overwrite)
self.log(f"Image {image_name} written to sdata object.")
self._check_sdata_status()

Expand All @@ -253,13 +234,7 @@ def _write_segmentation_object_sdata(
overwrite: Whether to overwrite existing data
"""
_sdata = self._read_sdata()

if overwrite:
self._force_delete_object(_sdata, segmentation_label, "labels")

_sdata.labels[segmentation_label] = segmentation_object
_sdata.write_element(segmentation_label, overwrite=True)

add_element_sdata(_sdata, segmentation_object, segmentation_label, overwrite=overwrite)
self.log(f"Segmentation {segmentation_label} written to sdata object.")
self._check_sdata_status()

Expand All @@ -281,10 +256,7 @@ def _write_segmentation_sdata(
"""
transform_original = Identity()
mask = Labels2DModel.parse(
segmentation,
dims=["y", "x"],
transformations={"global": transform_original},
chunks=chunks,
segmentation, dims=["y", "x"], transformations={"global": transform_original}, chunks=chunks
)

if not get_chunk_size(mask) == chunks:
Expand All @@ -301,14 +273,9 @@ def _write_points_object_sdata(self, points: PointsModel, points_name: str, over
overwrite: Whether to overwrite existing data
"""
_sdata = self._read_sdata()

if overwrite:
self._force_delete_object(_sdata, points_name, "points")

_sdata.points[points_name] = points
_sdata.write_element(points_name, overwrite=True)

add_element_sdata(_sdata, points, points_name, overwrite=overwrite)
self.log(f"Points {points_name} written to sdata object.")
self._check_sdata_status()

def _write_table_sdata(
self, adata: AnnData, table_name: str, segmentation_mask_name: str, overwrite: bool = False
Expand Down Expand Up @@ -344,10 +311,7 @@ def _write_table_sdata(

adata.obs = obs
table = TableModel.parse(
adata,
region=[segmentation_mask_name],
region_key="region",
instance_key="instance_id",
adata, region=[segmentation_mask_name], region_key="region", instance_key="instance_id"
)

self._write_table_object_sdata(table, table_name, overwrite=overwrite)
Expand All @@ -361,14 +325,9 @@ def _write_table_object_sdata(self, table: TableModel, table_name: str, overwrit
overwrite: Whether to overwrite existing data
"""
_sdata = self._read_sdata()

if overwrite:
self._force_delete_object(_sdata, table_name, "tables")

_sdata.tables[table_name] = table
_sdata.write_element(table_name, overwrite=True)

add_element_sdata(_sdata, table, table_name, overwrite=overwrite)
self.log(f"Table {table_name} written to sdata object.")
self._check_sdata_status()

def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsModel:
"""Get cell centers from segmentation.
Expand Down Expand Up @@ -433,11 +392,7 @@ def _load_input_image_to_memmap(
shape = image.shape

# initialize empty memory mapped arrays to store the data
path_input_image = tempmmap.create_empty_mmap(
shape=shape,
dtype=image.dtype,
tmp_dir_abs_path=tmp_dir_abs_path,
)
path_input_image = tempmmap.create_empty_mmap(shape=shape, dtype=image.dtype, tmp_dir_abs_path=tmp_dir_abs_path)

input_image_mmap = tempmmap.mmap_array_from_path(path_input_image)

Expand Down Expand Up @@ -521,9 +476,7 @@ def _load_seg_to_memmap(

# initialize empty memory mapped arrays to store the data
path_seg_masks = tempmmap.create_empty_mmap(
shape=shape,
dtype=seg_objects[0].data.dtype,
tmp_dir_abs_path=tmp_dir_abs_path,
shape=shape, dtype=seg_objects[0].data.dtype, tmp_dir_abs_path=tmp_dir_abs_path
)

seg_masks = tempmmap.mmap_array_from_path(path_seg_masks)
Expand Down

0 comments on commit 206c5d2

Please sign in to comment.