Skip to content

Commit

Permalink
add possibility to turn off masking during single-cell image extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Feb 12, 2025
1 parent c81cd39 commit 72aa76b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class ProcessingStep(Logable):
# extraction
DEFAULT_EXTRACTION_DIR_NAME = "extraction"
DEFAULT_DATA_DIR = "data"
DEFAULT_DATA_FILE = "single_cells.h5"
DEFAULT_EXTRACTION_FILE = "single_cells.h5"
DEFAULT_EXTRACTION_FILE_NO_MASK = "cropped_rois.h5"

# classification
DEFAULT_CLASSIFICATION_DIR_NAME = "classification"
Expand Down
2 changes: 1 addition & 1 deletion src/scportrait/pipeline/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ class based on the previous single-cell extraction. Therefore, no parameters nee

# generate dataloader
dataloader = self.generate_dataloader(
f"{extraction_dir}/{self.DEFAULT_DATA_FILE}"
f"{extraction_dir}/{self.DEFAULT_EXTRACTION_FILE}"
)

# perform inference
Expand Down
32 changes: 24 additions & 8 deletions src/scportrait/pipeline/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@ def _verbalise_extraction_info(self):
self.log(
f"Extracted Image Dimensions: {self.extracted_image_size} x {self.extracted_image_size}"
)
self.log(
f"Resulting images will be masked to include information from single-cells only: {self.mask_images}"
)

def _generate_save_index_lookup(self, class_list):
self.save_index_lookup = pd.DataFrame(index=class_list)
Expand Down Expand Up @@ -613,9 +616,15 @@ def _transfer_tempmmap_to_hdf5(self):
self.log("Transferring extracted single cells to .hdf5")

# create name for output file
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_EXTRACTION_FILE
)

if self.mask_images:
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_EXTRACTION_FILE
)
else:
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_EXTRACTION_FILE_NO_MASK
)

with h5py.File(self.output_path, "w") as hf:
hf.create_dataset(
Expand Down Expand Up @@ -761,7 +770,6 @@ def _extract_classes(
slice(y // 2 - 3, y // 2 + 3),
]
print("center of nucleus array \n", center_nuclei, "\n")

mask = np.where(mask == ids[mask_ix], 1, 0).astype(int)
mask = binary_fill_holes(mask)
mask = gaussian(mask, preserve_range=True, sigma=1)
Expand All @@ -776,7 +784,8 @@ def _extract_classes(
# image_data = self.input_image[image_index, :, window_y, window_x].compute()
channel = hdf_channels[image_index, i, window_y, window_x]

channel = channel * masks[-1]
if self.mask_images:
channel = channel * masks[-1]
channel = self.norm_function(channel)

image_data.append(channel)
Expand Down Expand Up @@ -914,6 +923,7 @@ def process(
partial=False,
n_cells=None,
seed=42,
mask_single_cell_images=True,
):
"""
Extracts single cell images from a segmented SPARCSpy project and saves the results to an HDF5 file.
Expand Down Expand Up @@ -963,6 +973,7 @@ def process(
"""

total_time_start = timeit.default_timer()
self.mask_images = mask_single_cell_images

# run all of the extraction setup steps
start_setup = timeit.default_timer()
Expand Down Expand Up @@ -1183,9 +1194,14 @@ def _transfer_tempmmap_to_hdf5(self):
self.log("Creating HDF5 file to save results to.")

# define output path
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_DATA_FILE
)
if self.mask_images:
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_EXTRACTION_FILE
)
else:
self.output_path = os.path.join(
self.extraction_data_directory, self.DEFAULT_EXTRACTION_FILE_NO_MASK
)

with h5py.File(self.output_path, "w") as hf:
self.log(
Expand Down
5 changes: 3 additions & 2 deletions src/scportrait/pipeline/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ class Project(Logable):
# extraction
DEFAULT_EXTRACTION_DIR_NAME = "extraction"
DEFAULT_DATA_DIR = "data"
DEFAULT_DATA_FILE = "single_cells.h5"
DEFAULT_EXTRACTION_FILE_NAME = "single_cells.h5"
DEFAULT_EXTRACTION_FILE_NAME_NO_MASKS = "cropped_rois.h5"

# classification
DEFAULT_CLASSIFICATION_DIR_NAME = "classification"
Expand Down Expand Up @@ -846,7 +847,7 @@ def classify(self, partial=False, *args, **kwargs):
raise ValueError("input was not found at {}".format(input_extraction))

self.classification_f(
f"{input_extraction}/{self.DEFAULT_DATA_FILE}",
f"{input_extraction}/{self.DEFAULT_EXTRACTION_FILE_NAME}",
partial=partial,
*args,
**kwargs,
Expand Down

0 comments on commit 72aa76b

Please sign in to comment.