diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index 7c614d4e..c28ea557 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -86,11 +86,69 @@ def main( callback: Optional[Callable[[int], None]] = None, ) -> List[Cell]: """ + Perform cell candidate detection on a 3D signal array. + Parameters ---------- + signal_array : numpy.ndarray + 3D array representing the signal data. + + start_plane : int + Index of the starting plane for detection. + + end_plane : int + Index of the ending plane for detection. + + voxel_sizes : Tuple[float, float, float] + Tuple of voxel sizes in each dimension (x, y, z). + + soma_diameter : float + Diameter of the soma in physical units. + + max_cluster_size : float + Maximum size of a cluster in physical units. + + ball_xy_size : float + Size of the XY ball used for filtering in physical units. + + ball_z_size : float + Size of the Z ball used for filtering in physical units. + + ball_overlap_fraction : float + Fraction of overlap allowed between balls. + + soma_spread_factor : float + Spread factor for soma size. + + n_free_cpus : int + Number of free CPU cores available for parallel processing. + + log_sigma_size : float + Size of the sigma for the log filter. + + n_sds_above_mean_thresh : float + Number of standard deviations above the mean threshold. + + outlier_keep : bool, optional + Whether to keep outliers during detection. Defaults to False. + + artifact_keep : bool, optional + Whether to keep artifacts during detection. Defaults to False. + + save_planes : bool, optional + Whether to save the planes during detection. Defaults to False. + + plane_directory : str, optional + Directory path to save the planes. Defaults to None. + callback : Callable[int], optional A callback function that is called every time a plane has finished being processed. Called with the plane number that has finished. + + Returns + ------- + List[Cell] + List of detected cells. """ if not np.issubdtype(signal_array.dtype, np.integer): raise ValueError( @@ -117,6 +175,7 @@ def main( if end_plane == -1: end_plane = len(signal_array) signal_array = signal_array[start_plane:end_plane] + signal_array = signal_array.astype(np.uint32) callback = callback or (lambda *args, **kwargs: None) diff --git a/cellfinder/core/detect/filters/plane/classical_filter.py b/cellfinder/core/detect/filters/plane/classical_filter.py index d72cf615..af331d52 100644 --- a/cellfinder/core/detect/filters/plane/classical_filter.py +++ b/cellfinder/core/detect/filters/plane/classical_filter.py @@ -6,6 +6,31 @@ def enhance_peaks( img: np.ndarray, clipping_value: float, gaussian_sigma: float = 2.5 ) -> np.ndarray: + """ + Enhances the peaks (bright pixels) in an input image. + + Parameters: + ---------- + img : np.ndarray + Input image. + clipping_value : float + Maximum value for the enhanced image. + gaussian_sigma : float, optional + Standard deviation for the Gaussian filter. Default is 2.5. + + Returns: + ------- + np.ndarray + Enhanced image with peaks. + + Notes: + ------ + The enhancement process includes the following steps: + 1. Applying a 2D median filter. + 2. Applying a Laplacian of Gaussian filter (LoG). + 3. Multiplying by -1 (bright spots respond negative in a LoG). + 4. Rescaling image values to range from 0 to clipping value. + """ type_in = img.dtype filtered_img = medfilt2d(img.astype(np.float64)) filtered_img = gaussian_filter(filtered_img, gaussian_sigma) diff --git a/cellfinder/core/detect/filters/volume/ball_filter.py b/cellfinder/core/detect/filters/volume/ball_filter.py index 13aed04c..ebd3642e 100644 --- a/cellfinder/core/detect/filters/volume/ball_filter.py +++ b/cellfinder/core/detect/filters/volume/ball_filter.py @@ -104,7 +104,7 @@ def __init__( # Stores the current planes that are being filtered self.volume = np.empty( - (plane_width, plane_height, ball_z_size), dtype=np.uint16 + (plane_width, plane_height, ball_z_size), dtype=np.uint32 ) # Index of the middle plane in the volume self.middle_z_idx = int(np.floor(ball_z_size / 2)) @@ -165,7 +165,7 @@ def get_middle_plane(self) -> np.ndarray: Get the plane in the middle of self.volume. """ z = self.middle_z_idx - return np.array(self.volume[:, :, z], dtype=np.uint16) + return np.array(self.volume[:, :, z], dtype=np.uint32) def walk(self) -> None: # Highly optimised because most time critical ball_radius = self.ball_xy_size // 2 diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index 4240291a..0573e615 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -28,7 +28,7 @@ def coords_to_volume( expanded_shape = [ dim_size + ball_diameter for dim_size in get_shape(xs, ys, zs) ] - volume = np.zeros(expanded_shape, dtype=np.uint16) + volume = np.zeros(expanded_shape, dtype=np.uint32) x_min, y_min, z_min = xs.min(), ys.min(), zs.min() @@ -38,7 +38,7 @@ def coords_to_volume( # OPTIMISE: vectorize for rel_x, rel_y, rel_z in zip(relative_xs, relative_ys, relative_zs): - volume[rel_x, rel_y, rel_z] = 65534 + volume[rel_x, rel_y, rel_z] = np.iinfo(volume.dtype).max - 1 return volume @@ -49,6 +49,26 @@ def ball_filter_imgs( ball_xy_size: int = 3, ball_z_size: int = 3, ) -> Tuple[np.ndarray, np.ndarray]: + """ + Apply ball filtering to a 3D volume and detect cell centres. + + Uses the `BallFilter` class to perform ball filtering on the volume + and the `CellDetector` class to detect cell centres. + + Args: + volume (np.ndarray): The 3D volume to be filtered. + threshold_value (int): The threshold value for ball filtering. + soma_centre_value (int): The value representing the soma centre. + ball_xy_size (int, optional): + The size of the ball filter in the XY plane. Defaults to 3. + ball_z_size (int, optional): + The size of the ball filter in the Z plane. Defaults to 3. + + Returns: + Tuple[np.ndarray, np.ndarray]: + A tuple containing the filtered volume and the cell centres. + + """ # OPTIMISE: reuse ball filter instance good_tiles_mask = np.ones((1, 1, volume.shape[2]), dtype=bool) @@ -71,10 +91,10 @@ def ball_filter_imgs( ) # FIXME: hard coded type - ball_filtered_volume = np.zeros(volume.shape, dtype=np.uint16) + ball_filtered_volume = np.zeros(volume.shape, dtype=np.uint32) previous_plane = None for z in range(volume.shape[2]): - bf.append(volume[:, :, z].astype(np.uint16), good_tiles_mask[:, :, z]) + bf.append(volume[:, :, z].astype(np.uint32), good_tiles_mask[:, :, z]) if bf.ready: bf.walk() middle_plane = bf.get_middle_plane() @@ -89,11 +109,24 @@ def ball_filter_imgs( def iterative_ball_filter( volume: np.ndarray, n_iter: int = 10 ) -> Tuple[List[int], List[np.ndarray]]: + """ + Apply iterative ball filtering to the given volume. + The volume is eroded at each iteration, by subtracting 1 from the volume. + + Parameters: + volume (np.ndarray): The input volume. + n_iter (int): The number of iterations to perform. Default is 10. + + Returns: + Tuple[List[int], List[np.ndarray]]: A tuple containing two lists: + The structures found in each iteration. + The cell centres found in each iteration. + """ ns = [] centres = [] - threshold_value = 65534 - soma_centre_value = 65535 + threshold_value = np.iinfo(volume.dtype).max - 1 + soma_centre_value = np.iinfo(volume.dtype).max vol = volume.copy() # TODO: check if required @@ -131,6 +164,21 @@ def check_centre_in_cuboid(centre: np.ndarray, max_coords: np.ndarray) -> bool: def split_cells( cell_points: np.ndarray, outlier_keep: bool = False ) -> np.ndarray: + """ + Split the given cell points into individual cell centres. + + Args: + cell_points (np.ndarray): Array of cell points with shape (N, 3), + where N is the number of cell points and each point is represented + by its x, y, and z coordinates. + outlier_keep (bool, optional): Flag indicating whether to keep outliers + during the splitting process. Defaults to False. + + Returns: + np.ndarray: Array of absolute cell centres with shape (M, 3), + where M is the number of individual cells and each centre is + represented by its x, y, and z coordinates. + """ orig_centre = get_structure_centre(cell_points) xs = cell_points[:, 0] diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index a29ada9f..727f0741 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -163,10 +163,10 @@ def test_data_dimension_error(ndim): # Check for an error when non-3D data input shape = (2, 3, 4, 5)[:ndim] signal_array = np.random.randint( - low=0, high=2**16, size=shape, dtype=np.uint16 + low=0, high=2**16, size=shape, dtype=np.uint32 ) background_array = np.random.randint( - low=0, high=2**16, size=shape, dtype=np.uint16 + low=0, high=2**16, size=shape, dtype=np.uint32 ) with pytest.raises(ValueError, match="Input data must be 3D"):