improve selection workflow
- only calculate coordinates for cells that are to be selected, instead of for all cells
- significant speed increase, especially for very large datasets
sophiamaedler committed Dec 17, 2024
1 parent d6c0721 commit d3e1225
188 changes: 150 additions & 38 deletions src/scportrait/pipeline/selection.py
@@ -1,12 +1,18 @@
import multiprocessing as mp
import os
import pickle
import timeit
from functools import partial as func_partial
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
from scipy.sparse import coo_array
from tqdm.auto import tqdm

from lmd.lib import SegmentationLoader
from scportrait.pipeline.base import ProcessingStep
from scportrait.processing.utils import flatten

class LMDSelection(ProcessingStep):
"""
@@ -22,12 +28,136 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# check config for required parameters
self._check_config()

def _check_config(self):
    # check mandatory config parameters
    assert "segmentation_channel" in self.config, "segmentation_channel not defined in config"
    self.segmentation_channel_to_select = self.config["segmentation_channel"]

    # check for optional config parameters
    # half-width of the bounding box cropped around each cell center (in pixels)
    if "cell_width" in self.config:
        self.cell_radius = self.config["cell_width"]
    else:
        self.cell_radius = 100

    if "threads" in self.config:
        self.threads = self.config["threads"]
        assert isinstance(self.threads, int), "threads must be an integer"
        assert self.threads > 0, "threads must be greater than 0"
    else:
        self.threads = 10

    if "batch_size_coordinate_extraction" in self.config:
        self.batch_size = self.config["batch_size_coordinate_extraction"]
        assert isinstance(self.batch_size, int), "batch_size_coordinate_extraction must be an integer"
        assert self.batch_size > 0, "batch_size_coordinate_extraction must be greater than 0"
    else:
        self.batch_size = 100
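    # A config passing these checks might look like this (illustrative values,
    # not defaults mandated by scPortrait):
    #   segmentation_channel: 0
    #   cell_width: 60
    #   threads: 10
    #   batch_size_coordinate_extraction: 100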

def __get_coords(
    self,
    cell_ids: list,
    centers: list[tuple[int, int]],
    hdf_path: str,
    segmentation_channel: int,
    width: int = 60,
) -> list[tuple[int, np.ndarray]]:
    results = []

    with h5py.File(hdf_path, "r") as hf:
        hdf_labels = hf["labels"]
        for i, _id in enumerate(cell_ids):
            center_x, center_y = centers[i]

            # crop a window of side length 2 * width around the cell center,
            # clipped to the array bounds at the origin
            x_start = max(int(center_x) - width, 0)
            y_start = max(int(center_y) - width, 0)
            x_end = x_start + width * 2
            y_end = y_start + width * 2

            _cropped = hdf_labels[segmentation_channel, x_start:x_end, y_start:y_end]

            # boolean mask of the pixels belonging to this cell id, stored sparsely
            sparse = coo_array(_cropped == _id)

            # shift the local crop coordinates back into the global coordinate system
            x = sparse.coords[0] + x_start
            y = sparse.coords[1] + y_start

            results.append((_id, np.array(list(zip(x, y)))))
    return results
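    # Worked example (hypothetical numbers): for a cell centered at (500, 300)
    # with width=60, the crop spans labels[channel, 440:560, 240:360]. A pixel
    # of that cell found at local position (10, 5) within the crop is mapped
    # back to global coordinates (450, 245) by adding (x_start, y_start).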

def _get_coords_multi(
    self, hdf_path: str, segmentation_channel: int, width: int, arg: tuple[list[int], np.ndarray]
) -> list[tuple[int, np.ndarray]]:
    # unpack the batched argument (pool.imap can only pass a single object per task)
    cell_ids, centers = arg
    return self.__get_coords(cell_ids, centers, hdf_path, segmentation_channel, width)

def _get_coords(
    self,
    cell_ids: list,
    centers: list[tuple[int, int]],
    hdf_path: str,
    segmentation_channel: int,
    width: int = 60,
    batch_size: int = 100,
    threads: int = 10,
) -> dict:
    # split the cell ids and their centers into batches of at most batch_size entries
    slices = [(start, min(start + batch_size, len(cell_ids))) for start in range(0, len(cell_ids), batch_size)]
    batched_args = [(cell_ids[start:end], centers[start:end]) for start, end in slices]

    # pre-bind the fixed arguments so that only the per-batch data is shipped to workers
    f = func_partial(self._get_coords_multi, hdf_path, segmentation_channel, width)

    if threads == 1:
        # call the function directly to avoid the overhead of multiprocessing
        results = [f(arg) for arg in batched_args]
    else:
        with mp.get_context(self.context).Pool(processes=threads) as pool:
            results = list(
                tqdm(
                    pool.imap(f, batched_args),
                    total=len(batched_args),
                    desc="Processing cell batches",
                )
            )
            pool.close()
            pool.join()

    # flatten the per-batch lists of (cell_id, coords) tuples into one lookup dict
    results = flatten(results)
    return dict(results)
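    # The returned lookup maps each selected cell id to its pixel coordinates,
    #   {cell_id: np.ndarray of shape (n_pixels, 2)}
    # and is handed to SegmentationLoader as coords_lookup in process() below.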

def _get_cell_ids(self, cell_sets: list[dict]) -> list[int]:
    cell_ids = []
    for cell_set in cell_sets:
        if "classes" in cell_set:
            cell_ids.extend(cell_set["classes"])
        else:
            self.log(f"Cell set {cell_set['name']} does not contain any classes.")
    return cell_ids

def _get_centers(self, cell_ids: list[int]) -> list[tuple[int, int]]:
    centers_path = Path(self.project_location) / "extraction" / "center.pickle"
    _ids_path = Path(self.project_location) / "extraction" / "_cell_ids.pickle"

    if centers_path.exists() and _ids_path.exists():
        with open(centers_path, "rb") as f:
            centers = pickle.load(f)
        with open(_ids_path, "rb") as f:
            _ids = pickle.load(f)
    else:
        raise ValueError("Center and cell id files not found.")

    centers = pd.DataFrame(centers, columns=["x", "y"])

    # convert coordinates to integers for compatibility with indexing into the segmentation mask
    centers.x = centers.x.astype(int)
    centers.y = centers.y.astype(int)
    centers["cell_id"] = _ids
    centers.set_index("cell_id", inplace=True)

    # subset to the requested cell ids, preserving their order
    centers = centers.loc[cell_ids, :]

    return centers[["x", "y"]].values.tolist()
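    # Example with hypothetical values: centers [(10.7, 4.2), (99.1, 50.8)] and
    # _ids [17, 42]; requesting cell_ids=[42] yields [[99, 50]]. The .loc lookup
    # keeps the order of the requested ids, so centers[i] lines up with
    # cell_ids[i] in __get_coords.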

def process(self, hdf_location, cell_sets, calibration_marker, name=None):
"""
@@ -135,35 +265,19 @@ class will automatically provide the most recent segmentation together with the
self.log("Selection process started")

# calculate a coordinate lookup index: for each selected cell id, the pixel
# coordinates of its location in the segmentation mask
self.log("Calculating coordinate lookup index for the specified cell ids.")
start_time = timeit.default_timer()

cell_ids = self._get_cell_ids(cell_sets)
centers = self._get_centers(cell_ids)
coord_index = self._get_coords(
    cell_ids=cell_ids,
    centers=centers,
    hdf_path=hdf_location,
    segmentation_channel=self.segmentation_channel_to_select,
    width=self.cell_radius,
    batch_size=self.batch_size,
    threads=self.threads,
)
self.log(f"Coordinate lookup index calculation took {timeit.default_timer() - start_time} seconds.")

# add default orientation transform
self.config["orientation_transform"] = np.array([[0, -1], [1, 0]])

@@ -173,7 +287,7 @@ class will automatically provide the most recent segmentation together with the
processes=self.config["processes_cell_sets"],
)

shape_collection = sl(None, cell_sets, calibration_marker, coords_lookup=coord_index)
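# the first argument (previously the full segmentation mask) can now be None:
# coords_lookup already contains the pixel coordinates of every selected cell,
# so the mask no longer needs to be loaded, which appears to be the main
# memory saving of this commit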

if self.debug:
shape_collection.plot(calibration=True)
@@ -189,6 +303,4 @@ class will automatically provide the most recent segmentation together with the
savepath = os.path.join(self.directory, savename)
shape_collection.save(savepath)

self.log(f"Saved output at {savepath}")
