Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute NaN Function #2086

Merged
merged 3 commits into from
Mar 8, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 29 additions & 37 deletions mantidimaging/core/operations/nan_removal/nan_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
from __future__ import annotations

from functools import partial
from logging import getLogger
from typing import Dict, TYPE_CHECKING

import numpy as np
import scipy.ndimage as scipy_ndimage
from scipy.ndimage import median_filter

from mantidimaging.core.operations.base_filter import BaseFilter
from mantidimaging.core.parallel import shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility.qt_helpers import Type

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,16 +50,27 @@ def filter_func(data, replace_value=None, mode_value="Constant", progress=None)
"""

if mode_value == "Constant":
sample = data.data
nan_idxs = np.isnan(sample)
sample[nan_idxs] = replace_value
params = {'replace_value': replace_value}
ps.run_compute_func(NaNRemovalFilter.compute_constant_function, data.data.shape[0], data.shared_array,
params, progress)
elif mode_value == "Median":
_execute(data, 3, "reflect", progress)
ps.run_compute_func(NaNRemovalFilter.compute_median_function, data.data.shape[0], data.shared_array, {},
progress)
else:
raise ValueError(f"Unknown mode: '{mode_value}'\nShould be one of {NaNRemovalFilter.MODES}")
raise ValueError(f"Unknown mode: '{mode_value}'. Should be one of {NaNRemovalFilter.MODES}")

return data

@staticmethod
def compute_constant_function(i: int, array: np.ndarray, params: dict):
replace_value = params['replace_value']
nan_idxs = np.isnan(array[i])
array[i][nan_idxs] = replace_value

@staticmethod
def compute_median_function(i: int, array: np.ndarray, params: dict):
array[i] = NaNRemovalFilter._nan_to_median(array[i], size=3, edgemode='reflect')

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindowView') -> Dict[str, 'QWidget']:
from mantidimaging.gui.utility import add_property_to_form
Expand All @@ -87,37 +96,20 @@ def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindow

return {"mode_field": mode_field, "replace_value_field": replace_value_field}

@staticmethod
def _nan_to_median(data: np.ndarray, size: int, edgemode: str):
nans = np.isnan(data)
if np.any(nans):
median_data = np.where(nans, -np.inf, data)
median_data = median_filter(median_data, size=size, mode=edgemode)
data = np.where(nans, median_data, data)
if np.any(data == -np.inf):
data = np.where(np.logical_and(nans, data == -np.inf), np.nan, data)

return data

@staticmethod
def execute_wrapper(mode_field=None, replace_value_field=None):
mode_value = mode_field.currentText()
replace_value = replace_value_field.value()
return partial(NaNRemovalFilter.filter_func, replace_value=replace_value, mode_value=mode_value)


def _nan_to_median(data: np.ndarray, size: int, edgemode: str):
nans = np.isnan(data)
if np.any(nans):
samtygier-stfc marked this conversation as resolved.
Show resolved Hide resolved
median_data = np.where(nans, -np.inf, data)
median_data = scipy_ndimage.median_filter(median_data, size=size, mode=edgemode)
data = np.where(nans, median_data, data)

if np.any(data == -np.inf):
# Convert any left over -infs back to NaNs
data = np.where(np.logical_and(nans, data == -np.inf), np.nan, data)

return data


def _execute(images: ImageStack, size, edgemode, progress=None):
log = getLogger(__name__)
progress = Progress.ensure_instance(progress, task_name='NaN Removal')

# create the partial function to forward the parameters
f = ps.create_partial(_nan_to_median, ps.return_to_self, size=size, edgemode=edgemode)

with progress:
log.info("PARALLEL NaN Removal filter, with pixel data type: {0}".format(images.dtype))

ps.execute(f, [images.shared_array], images.data.shape[0], progress, msg="NaN Removal")

return images
Loading