Skip to content

Commit

Permalink
Merge pull request #134 from MannLabs/improve_docs_docstrings
Browse files Browse the repository at this point in the history
Improve docs docstrings
  • Loading branch information
sophiamaedler authored Dec 13, 2024
2 parents 9d56ea9 + 8023b1c commit a94b8cb
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 254 deletions.
60 changes: 57 additions & 3 deletions src/scportrait/pipeline/_utils/sdata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray
from alphabase.io import tempmmap
from spatialdata import SpatialData
from spatialdata.models import PointsModel, TableModel
from spatialdata.models import Image2DModel, PointsModel, TableModel
from spatialdata.transformations.transformations import Identity

from scportrait.pipeline._base import Logable
Expand All @@ -19,7 +19,8 @@
get_chunk_size,
)

ChunkSize: TypeAlias = tuple[int, int]
ChunkSize2D: TypeAlias = tuple[int, int]
ChunkSize3D: TypeAlias = tuple[int, int, int]
ObjectType: TypeAlias = Literal["images", "labels", "points", "tables"]


Expand Down Expand Up @@ -143,6 +144,58 @@ def _get_input_image(self, sdata: SpatialData) -> xarray.DataArray:
return input_image

## write elements to sdata object
def _write_image_sdata(
self,
image,
image_name: str,
channel_names: list[str] = None,
scale_factors: list[int] = None,
chunks: ChunkSize3D = (1, 1000, 1000),
overwrite=False,
):
"""
Write the supplied image to the spatialdata object.
Args:
image (dask.array): Image to be written to the spatialdata object.
image_name (str): Name of the image to be written to the spatialdata object.
channel_names list[str]: List of channel names for the image. Default is None.
scale_factors list[int]: List of scale factors for the image. Default is [2, 4, 8]. This will load the image at 4 different resolutions to allow for fluid visualization.
chunks (tuple): Chunk size for the image. Default is (1, 1000, 1000).
overwrite (bool): Whether to overwrite existing data. Default is False.
"""

if scale_factors is None:
scale_factors = [2, 4, 8]
if scale_factors is None:
scale_factors = [2, 4, 8]

_sdata = self._read_sdata()

if channel_names is None:
channel_names = [f"channel_{i}" for i in range(image.shape[0])]

# transform to spatialdata image model
transform_original = Identity()
image = Image2DModel.parse(
image,
dims=["c", "y", "x"],
chunks=chunks,
c_coords=channel_names,
scale_factors=scale_factors,
transformations={"global": transform_original},
rgb=False,
)

if overwrite:
self._force_delete_object(_sdata, image_name, "images")

_sdata.images[image_name] = image
_sdata.write_element(image_name, overwrite=True)

self.log(f"Image {image_name} written to sdata object.")
self._check_sdata_status()

def _write_segmentation_object_sdata(
self,
segmentation_object: spLabels2DModel,
Expand Down Expand Up @@ -177,7 +230,7 @@ def _write_segmentation_sdata(
segmentation: xarray.DataArray | np.ndarray,
segmentation_label: str,
classes: set[str] | None = None,
chunks: ChunkSize = (1000, 1000),
chunks: ChunkSize2D = (1000, 1000),
overwrite: bool = False,
) -> None:
"""Write segmentation data to SpatialData.
Expand Down Expand Up @@ -268,6 +321,7 @@ def _add_centers(self, segmentation_label: str, overwrite: bool = False) -> None
centroids_object = self._get_centers(_sdata, segmentation_label)
self._write_points_object_sdata(centroids_object, self.centers_name, overwrite=overwrite)

## load elements from sdata to a memory mapped array
def _load_input_image_to_memmap(
self, tmp_dir_abs_path: str | Path, image: np.typing.NDArray[Any] | None = None
) -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/scportrait/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,10 @@ def process(self, partial=False, n_cells=None, seed=42):
self.log("Loading input images to memory mapped arrays...")
start_data_transfer = timeit.default_timer()

self.path_seg_masks = self.project._load_seg_to_memmap(seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path)
self.path_image_data = self.project._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path)
self.path_seg_masks = self.filehandler._load_seg_to_memmap(
seg_name=self.masks, tmp_dir_abs_path=self._tmp_dir_path
)
self.path_image_data = self.filehandler._load_input_image_to_memmap(tmp_dir_abs_path=self._tmp_dir_path)

stop_data_transfer = timeit.default_timer()
time_data_transfer = stop_data_transfer - start_data_transfer
Expand Down
Loading

0 comments on commit a94b8cb

Please sign in to comment.