Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get_spectrum to no longer accept ROI as a string #2416

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,10 @@ def shuttercount_issue(self) -> str:
return "Need 2 different ShutterCount stacks"
return ""

def get_spectrum(self,
roi: str | SensibleROI,
mode: SpecType,
normalise_with_shuttercount: bool = False) -> np.ndarray:
def get_spectrum(self, roi: SensibleROI, mode: SpecType, normalise_with_shuttercount: bool = False) -> np.ndarray:
if self._stack is None:
return np.array([])

if isinstance(roi, str):
roi = self.get_roi(roi)

if mode == SpecType.SAMPLE:
return self.get_stack_spectrum(self._stack, roi)

Expand Down Expand Up @@ -330,36 +324,43 @@ def has_stack(self) -> bool:
"""
return self._stack is not None

def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: bool = False) -> None:
def save_csv(self,
path: Path,
rois: dict[str, SensibleROI],
normalise: bool,
normalise_with_shuttercount: bool = False) -> None:
"""
Iterates over all ROIs and saves the spectrum for each one to a CSV file.

@param path: The path to save the CSV file to.
@param normalized: Whether to save the normalized spectrum.

"""
if self._stack is None:
raise ValueError("No stack selected")
if not rois:
raise ValueError("No ROIs provided")

csv_output = CSVOutput()
csv_output.add_column("ToF_index", np.arange(self._stack.data.shape[0]), "Index")

self.tof_data = self.get_stack_time_of_flight()
if self.tof_data is not None:
self.units.set_data_to_convert(self.tof_data)
csv_output.add_column("Wavelength", self.units.tof_seconds_to_wavelength_in_angstroms(), "Angstrom")
csv_output.add_column("ToF", self.units.tof_seconds_to_us(), "Microseconds")
csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV")

for roi_name in self.get_list_of_roi_names():
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE, normalise_with_shuttercount),
for roi_name, roi in rois.items():
csv_output.add_column(roi_name, self.get_spectrum(roi, SpecType.SAMPLE, normalise_with_shuttercount),
"Counts")

if normalise:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts")
csv_output.add_column(roi_name + "_norm",
self.get_spectrum(roi_name, SpecType.SAMPLE_NORMED, normalise_with_shuttercount),
csv_output.add_column(f"{roi_name}_open", self.get_spectrum(roi, SpecType.OPEN), "Counts")
csv_output.add_column(f"{roi_name}_norm",
self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise_with_shuttercount),
"Counts")

with path.open("w") as outfile:
csv_output.write(outfile)
self.save_roi_coords(self.get_roi_coords_filename(path))
ashleymeigh2 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
70 changes: 41 additions & 29 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,15 @@ def handle_roi_moved(self, force_new_spectrums: bool = False) -> None:
Handle changes to any ROI position and size.
"""
for name in self.model.get_list_of_roi_names():
roi = self.view.spectrum_widget.get_roi(name)
if force_new_spectrums or roi != self.model.get_roi(name):
self.model.set_roi(name, roi)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))
view_roi = self.view.spectrum_widget.get_roi(name)
if force_new_spectrums or view_roi != self.model.get_roi(name):
self.model.set_roi(name, view_roi)
spectrum = self.model.get_spectrum(
view_roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled(),
)
self.view.set_spectrum(name, spectrum)

def handle_roi_clicked(self, roi: SpectrumROI) -> None:
if not roi.name == ROI_RITS:
Expand All @@ -220,9 +221,10 @@ def redraw_spectrum(self, name: str) -> None:
"""
Redraw the spectrum with the given name
"""
roi = self.model.get_roi(name)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))

Expand All @@ -231,12 +233,13 @@ def redraw_all_rois(self) -> None:
Redraw all ROIs and spectrum plots
"""
for name in self.model.get_list_of_roi_names():
if name == "all" or self.view.spectrum_widget.roi_dict[name].isVisible() is False:
if name == "all" or not self.view.spectrum_widget.roi_dict[name].isVisible():
continue
self.model.set_roi(name, self.view.spectrum_widget.get_roi(name))
roi = self.view.spectrum_widget.get_roi(name)
self.model.set_roi(name, roi)
self.view.set_spectrum(
name,
self.model.get_spectrum(name,
self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled()))

Expand All @@ -257,13 +260,17 @@ def handle_button_enabled(self) -> None:

def handle_export_csv(self) -> None:
path = self.view.get_csv_filename()
if path is None:
if not path:
return
path = path.with_suffix(".csv") if path.suffix != ".csv" else path
rois = {roi.name: roi.as_sensible_roi() for roi in self.view.spectrum_widget.roi_dict.values()}

if path.suffix != ".csv":
path = path.with_suffix(".csv")

self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, self.view.shuttercount_norm_enabled())
self.model.save_csv(
path,
rois,
normalise=self.spectrum_mode == SpecType.SAMPLE_NORMED,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled(),
)

def handle_rits_export(self) -> None:
"""
Expand Down Expand Up @@ -331,10 +338,12 @@ def do_add_roi(self) -> None:
roi_name = self.model.roi_name_generator()
if roi_name in self.view.spectrum_widget.roi_dict:
raise ValueError(f"ROI name already exists: {roi_name}")

self.model.set_new_roi(roi_name)
self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
roi = self.model.get_roi(roi_name)
self.view.spectrum_widget.add_roi(roi, roi_name)
spectrum = self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled())
self.view.set_spectrum(roi_name, spectrum)
self.view.auto_range_image()
self.do_add_roi_to_table(roi_name)

Expand All @@ -351,11 +360,11 @@ def change_roi_colour(self, roi_name: str, new_colour: tuple[int, int, int]) ->
self.view.on_visibility_change()

def add_rits_roi(self) -> None:
roi_name = ROI_RITS
self.model.set_new_roi(roi_name)
self.view.spectrum_widget.add_roi(self.model.get_roi(roi_name), roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
self.model.set_new_roi(ROI_RITS)
roi = self.model.get_roi(ROI_RITS)
self.view.spectrum_widget.add_roi(roi, ROI_RITS)
self.view.set_spectrum(ROI_RITS,
self.model.get_spectrum(roi, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
self.view.set_roi_visibility_flags(ROI_RITS, visible=False)

def do_add_roi_to_table(self, roi_name: str) -> None:
Expand Down Expand Up @@ -386,13 +395,16 @@ def do_remove_roi(self, roi_name: str | None = None) -> None:
"""
if roi_name is None:
self.view.clear_all_rois()
for roi in self.get_roi_names():
self.view.spectrum_widget.remove_roi(roi)
for name in self.get_roi_names():
self.view.spectrum_widget.remove_roi(name)
self.model.remove_all_roi()
else:
roi = self.model.get_roi(roi_name)
self.view.spectrum_widget.remove_roi(roi_name)
self.view.set_spectrum(
roi_name, self.model.get_spectrum(roi_name, self.spectrum_mode, self.view.shuttercount_norm_enabled()))
spectrum = self.model.get_spectrum(roi,
self.spectrum_mode,
normalise_with_shuttercount=self.view.shuttercount_norm_enabled())
self.view.set_spectrum(roi_name, spectrum)
self.model.remove_roi(roi_name)

def handle_export_tab_change(self, index: int) -> None:
Expand Down
12 changes: 12 additions & 0 deletions mantidimaging/gui/windows/spectrum_viewer/spectrum_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ def adjust_spec_roi(self, roi: SensibleROI) -> None:
def rename_roi(self, new_name: str) -> None:
self._name = new_name

def as_sensible_roi(self) -> SensibleROI:
"""
Converts the SpectrumROI to a SensibleROI object.
"""
pos = self.pos()
size = self.size()
left, top = pos
width, height = size
right = left + width
bottom = top + height
return SensibleROI.from_list([left, top, right, bottom])


class SpectrumWidget(QWidget):
"""
Expand Down
51 changes: 32 additions & 19 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,20 @@ def test_get_averaged_image_range(self):

def test_get_spectrum(self):
stack, spectrum = self._set_sample_stack()

model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
self.assertEqual(model_spec.shape, (10, ))
npt.assert_array_equal(model_spec, spectrum)

def test_get_normalised_spectrum(self):
stack, spectrum = self._set_sample_stack()

normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2)
self.model.set_normalise_stack(normalise_stack)

model_open_spec = self.model.get_spectrum("roi", SpecType.OPEN)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_open_spec = self.model.get_spectrum(roi, SpecType.OPEN)
self.assertEqual(model_open_spec.shape, (10, ))
self.assertTrue(np.all(model_open_spec == 2))

model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED)
model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
self.assertEqual(model_norm_spec.shape, (10, ))
npt.assert_array_equal(model_norm_spec, spectrum / 2)

Expand All @@ -122,8 +120,8 @@ def test_get_normalised_spectrum_zeros(self):
normalise_stack = ImageStack(np.ones([10, 11, 12]) * 2)
normalise_stack.data[5] = 0
self.model.set_normalise_stack(normalise_stack)

model_norm_spec = self.model.get_spectrum("roi", SpecType.SAMPLE_NORMED)
roi = SensibleROI(left=0, top=0, right=12, bottom=11)
model_norm_spec = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
expected_spec = spectrum / 2
expected_spec[5] = 0
self.assertEqual(model_norm_spec.shape, (10, ))
Expand All @@ -134,8 +132,8 @@ def test_get_normalised_spectrum_different_size(self):

normalise_stack = ImageStack(np.ones([10, 11, 13]))
self.model.set_normalise_stack(normalise_stack)

error_spectrum = self.model.get_spectrum("all", SpecType.SAMPLE_NORMED)
roi = SensibleROI(left=0, top=0, right=13, bottom=11)
error_spectrum = self.model.get_spectrum(roi, SpecType.SAMPLE_NORMED)
np.testing.assert_array_equal(error_spectrum, np.array([]))

def test_normalise_issue(self):
Expand Down Expand Up @@ -168,12 +166,12 @@ def test_get_spectrum_roi(self):
stack, spectrum = self._set_sample_stack()
stack.data[:, :, 6:] *= 2

self.model.set_roi('roi', SensibleROI.from_list([0, 0, 3, 3]))
model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI.from_list([0, 0, 3, 3])
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
npt.assert_array_equal(model_spec, spectrum)

self.model.set_roi('roi', SensibleROI.from_list([6, 0, 6 + 3, 3]))
model_spec = self.model.get_spectrum("roi", SpecType.SAMPLE)
roi = SensibleROI.from_list([6, 0, 6 + 3, 3])
model_spec = self.model.get_spectrum(roi, SpecType.SAMPLE)
npt.assert_array_equal(model_spec, spectrum * 2)

def test_get_stack_spectrum(self):
Expand All @@ -191,9 +189,13 @@ def test_save_csv(self):
stack.data *= 2
self.model.set_normalise_stack(None)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
roi_specific = SensibleROI.from_list([0, 0, 3, 3])
rois = {"all": roi_all, "roi": roi_specific}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, False)
self.model.save_csv(mock_path, rois=rois, normalise=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,all,roi", mock_stream.captured[0])
Expand Down Expand Up @@ -320,18 +322,26 @@ def test_save_csv_norm_missing_stack(self):
stack, _ = self._set_sample_stack()
stack.data *= 2
self.model.set_normalise_stack(None)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
rois = {"all": roi_all}

with self.assertRaises(RuntimeError):
self.model.save_csv(mock.Mock(), True)
self.model.save_csv(mock.Mock(), rois=rois, normalise=True)

def test_save_csv_norm(self):
self._set_sample_stack()

open_stack = ImageStack(np.ones([10, 11, 12]) * 2)
self.model.set_normalise_stack(open_stack)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
roi_specific = SensibleROI.from_list([0, 0, 3, 3])
rois = {"all": roi_all, "roi": roi_specific}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, True)
self.model.save_csv(path=mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,all,all_open,all_norm,roi,roi_open,roi_norm", mock_stream.captured[0])
Expand All @@ -346,9 +356,12 @@ def test_save_csv_norm_with_tof_loaded(self):
stack.data[:, :, :5] *= 2
self.model.set_normalise_stack(norm)

roi_all = SensibleROI.from_list([0, 0, 12, 11])
rois = {"all": roi_all, "roi": roi_all}

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, True)
self.model.save_csv(mock_path, rois=rois, normalise=True, normalise_with_shuttercount=False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,Wavelength,ToF,Energy,all,all_open,all_norm,roi,roi_open,roi_norm",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock, mock_
self.view.get_csv_filename = mock.Mock(return_value=Path(path_name))
self.view.shuttercount_norm_enabled.return_value = False
mock_shuttercount_issue.return_value = "Something wrong"

self.presenter.model.set_stack(generate_images())

self.presenter.handle_export_csv()

self.view.get_csv_filename.assert_called_once()
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False, False)
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), {},
normalise=False,
normalise_with_shuttercount=False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def test_WHEN_colour_is_not_valid_THEN_roi_colour_is_unchanged(self):
self.spectrum_roi.onChangeColor()
self.assertEqual(self.spectrum_roi.colour, (0, 0, 0, 255))

def test_WHEN_as_sensible_roi_called_THEN_correct_sensible_roi_returned(self):
sensible_roi = self.spectrum_roi.as_sensible_roi()
self.assertEqual((sensible_roi.left, sensible_roi.top, sensible_roi.right, sensible_roi.bottom),
(10, 20, 30, 40))


@mock_versions
@start_qapplication
Expand Down
Loading