Skip to content

Commit

Permalink
Compute NaN Function (#2086)
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc authored Mar 8, 2024
2 parents 10d7e40 + e6f0eef commit 589ecfc
Showing 1 changed file with 34 additions and 37 deletions.
71 changes: 34 additions & 37 deletions mantidimaging/core/operations/nan_removal/nan_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
from __future__ import annotations

from functools import partial
from logging import getLogger
from typing import Dict, TYPE_CHECKING

import numpy as np
import scipy.ndimage as scipy_ndimage
from scipy.ndimage import median_filter

from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility.qt_helpers import Type

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,16 +50,27 @@ def filter_func(data, replace_value=None, mode_value="Constant", progress=None)
"""

if mode_value == "Constant":
sample = data.data
nan_idxs = np.isnan(sample)
sample[nan_idxs] = replace_value
params = {'replace_value': replace_value}
ps.run_compute_func(NaNRemovalFilter.compute_constant_function, data.data.shape[0], data.shared_array,
params, progress)
elif mode_value == "Median":
_execute(data, 3, "reflect", progress)
ps.run_compute_func(NaNRemovalFilter.compute_median_function, data.data.shape[0], data.shared_array, {},
progress)
else:
raise ValueError(f"Unknown mode: '{mode_value}'\nShould be one of {NaNRemovalFilter.MODES}")
raise ValueError(f"Unknown mode: '{mode_value}'. Should be one of {NaNRemovalFilter.MODES}")

return data

@staticmethod
def compute_constant_function(i: int, array: np.ndarray, params: dict):
replace_value = params['replace_value']
nan_idxs = np.isnan(array[i])
array[i][nan_idxs] = replace_value

@staticmethod
def compute_median_function(i: int, array: np.ndarray, params: dict):
array[i] = NaNRemovalFilter._nan_to_median(array[i], size=3, edgemode='reflect')

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindowView') -> Dict[str, 'QWidget']:
from mantidimaging.gui.utility import add_property_to_form
Expand All @@ -87,37 +96,25 @@ def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindow

return {"mode_field": mode_field, "replace_value_field": replace_value_field}

@staticmethod
def _nan_to_median(data: np.ndarray, size: int, edgemode: str):
"""
Replaces NaN values in data with median, based on a kernel 'size' and 'edgemode'.
Initially converts NaNs to -inf to avoid calculation issues, applies a median filter.
After -inf changes back to NaNs to indicate unprocessed blocks.
"""
nans = np.isnan(data)
if np.any(nans):
median_data = np.where(nans, -np.inf, data)
median_data = median_filter(median_data, size=size, mode=edgemode)
data = np.where(nans, median_data, data)
if np.any(data == -np.inf):
data = np.where(np.logical_and(nans, data == -np.inf), np.nan, data)

return data

@staticmethod
def execute_wrapper(mode_field=None, replace_value_field=None):
mode_value = mode_field.currentText()
replace_value = replace_value_field.value()
return partial(NaNRemovalFilter.filter_func, replace_value=replace_value, mode_value=mode_value)


def _nan_to_median(data: np.ndarray, size: int, edgemode: str):
nans = np.isnan(data)
if np.any(nans):
median_data = np.where(nans, -np.inf, data)
median_data = scipy_ndimage.median_filter(median_data, size=size, mode=edgemode)
data = np.where(nans, median_data, data)

if np.any(data == -np.inf):
# Convert any left over -infs back to NaNs
data = np.where(np.logical_and(nans, data == -np.inf), np.nan, data)

return data


def _execute(images: ImageStack, size, edgemode, progress=None):
log = getLogger(__name__)
progress = Progress.ensure_instance(progress, task_name='NaN Removal')

# create the partial function to forward the parameters
f = ps.create_partial(_nan_to_median, ps.return_to_self, size=size, edgemode=edgemode)

with progress:
log.info("PARALLEL NaN Removal filter, with pixel data type: {0}".format(images.dtype))

ps.execute(f, [images.shared_array], images.data.shape[0], progress, msg="NaN Removal")

return images

0 comments on commit 589ecfc

Please sign in to comment.