From 6299cc4410ebcbe4bf42dc26cfca38c5a3b64244 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:00:40 +0100 Subject: [PATCH] [FIX] add support for calculating centers on multi-scaled segmentation masks --- src/scportrait/pipeline/_utils/sdata_io.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/scportrait/pipeline/_utils/sdata_io.py b/src/scportrait/pipeline/_utils/sdata_io.py index 479538e1..df2185aa 100644 --- a/src/scportrait/pipeline/_utils/sdata_io.py +++ b/src/scportrait/pipeline/_utils/sdata_io.py @@ -347,7 +347,10 @@ def _get_centers(self, sdata: SpatialData, segmentation_label: str) -> PointsMod if segmentation_label not in sdata.labels: raise ValueError(f"Segmentation {segmentation_label} not found in sdata object.") - centers = calculate_centroids(sdata.labels[segmentation_label]) + mask = sdata.labels[segmentation_label] + if isinstance(mask, xarray.DataTree): + mask = mask.scale0.image + centers = calculate_centroids(mask) return centers def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None: