Skip to content

Commit 9c237fc

Browse files
authored
Merge pull request #2530 from oesteban/enh/add-masks-overlap
ENH: Revise the implementation of FuzzyOverlap
2 parents 5c3c413 + 140b159 commit 9c237fc

File tree

3 files changed

+125
-72
lines changed

3 files changed

+125
-72
lines changed

nipype/algorithms/metrics.py

Lines changed: 66 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020

2121
from .. import config, logging
2222

23-
from ..interfaces.base import (BaseInterface, traits, TraitedSpec, File,
24-
InputMultiPath, BaseInterfaceInputSpec,
25-
isdefined)
23+
from ..interfaces.base import (
24+
SimpleInterface, BaseInterface, traits, TraitedSpec, File,
25+
InputMultiPath, BaseInterfaceInputSpec,
26+
isdefined)
2627
from ..interfaces.nipy.base import NipyBaseInterface
27-
from ..utils import NUMPY_MMAP
2828

2929
iflogger = logging.getLogger('interface')
3030

@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383383
File(exists=True),
384384
mandatory=True,
385385
desc='Test image. Requires the same dimensions as in_ref.')
386+
in_mask = File(exists=True, desc='calculate overlap only within mask')
386387
weighting = traits.Enum(
387388
'none',
388389
'volume',
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403404
class FuzzyOverlapOutputSpec(TraitedSpec):
404405
jaccard = traits.Float(desc='Fuzzy Jaccard Index (fJI), all the classes')
405406
dice = traits.Float(desc='Fuzzy Dice Index (fDI), all the classes')
406-
diff_file = File(
407-
exists=True,
408-
desc=
409-
'resulting difference-map of all classes, using the chosen weighting')
410407
class_fji = traits.List(
411408
traits.Float(),
412409
desc='Array containing the fJIs of each computed class')
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415412
desc='Array containing the fDIs of each computed class')
416413

417414

418-
class FuzzyOverlap(BaseInterface):
415+
class FuzzyOverlap(SimpleInterface):
419416
"""Calculates various overlap measures between two maps, using the fuzzy
420417
definition proposed in: Crum et al., Generalized Overlap Measures for
421418
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface):
439436
output_spec = FuzzyOverlapOutputSpec
440437

441438
def _run_interface(self, runtime):
442-
ncomp = len(self.inputs.in_ref)
443-
assert (ncomp == len(self.inputs.in_tst))
444-
weights = np.ones(shape=ncomp)
445-
446-
img_ref = np.array([
447-
nb.load(fname, mmap=NUMPY_MMAP).get_data()
448-
for fname in self.inputs.in_ref
449-
])
450-
img_tst = np.array([
451-
nb.load(fname, mmap=NUMPY_MMAP).get_data()
452-
for fname in self.inputs.in_tst
453-
])
454-
455-
msk = np.sum(img_ref, axis=0)
456-
msk[msk > 0] = 1.0
457-
tst_msk = np.sum(img_tst, axis=0)
458-
tst_msk[tst_msk > 0] = 1.0
459-
460-
# check that volumes are normalized
461-
# img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462-
# img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463-
464-
self._jaccards = []
465-
volumes = []
466-
467-
diff_im = np.zeros(img_ref.shape)
468-
469-
for ref_comp, tst_comp, diff_comp in zip(img_ref, img_tst, diff_im):
470-
num = np.minimum(ref_comp, tst_comp)
471-
ddr = np.maximum(ref_comp, tst_comp)
472-
diff_comp[ddr > 0] += 1.0 - (num[ddr > 0] / ddr[ddr > 0])
473-
self._jaccards.append(np.sum(num) / np.sum(ddr))
474-
volumes.append(np.sum(ref_comp))
475-
476-
self._dices = 2.0 * (np.array(self._jaccards) /
477-
(np.array(self._jaccards) + 1.0))
439+
# Load data
440+
refdata = nb.concat_images(self.inputs.in_ref).get_data()
441+
tstdata = nb.concat_images(self.inputs.in_tst).get_data()
442+
443+
# Data must have same shape
444+
if not refdata.shape == tstdata.shape:
445+
raise RuntimeError(
446+
'Size of "in_tst" %s must match that of "in_ref" %s.' %
447+
(tstdata.shape, refdata.shape))
448+
449+
ncomp = refdata.shape[-1]
478450

451+
# Load mask
452+
mask = np.ones_like(refdata, dtype=bool)
453+
if isdefined(self.inputs.in_mask):
454+
mask = nb.load(self.inputs.in_mask).get_data()
455+
mask = mask > 0
456+
mask = np.repeat(mask[..., np.newaxis], ncomp, -1)
457+
assert mask.shape == refdata.shape
458+
459+
# Drop data outside mask
460+
refdata = refdata[mask]
461+
tstdata = tstdata[mask]
462+
463+
if np.any(refdata < 0.0):
464+
iflogger.warning('Negative values encountered in "in_ref" input, '
465+
'taking absolute values.')
466+
refdata = np.abs(refdata)
467+
468+
if np.any(tstdata < 0.0):
469+
iflogger.warning('Negative values encountered in "in_tst" input, '
470+
'taking absolute values.')
471+
tstdata = np.abs(tstdata)
472+
473+
if np.any(refdata > 1.0):
474+
iflogger.warning('Values greater than 1.0 found in "in_ref" input, '
475+
'scaling values.')
476+
refdata /= refdata.max()
477+
478+
if np.any(tstdata > 1.0):
479+
iflogger.warning('Values greater than 1.0 found in "in_tst" input, '
480+
'scaling values.')
481+
tstdata /= tstdata.max()
482+
483+
numerators = np.atleast_2d(
484+
np.minimum(refdata, tstdata).reshape((-1, ncomp)))
485+
denominators = np.atleast_2d(
486+
np.maximum(refdata, tstdata).reshape((-1, ncomp)))
487+
488+
jaccards = numerators.sum(axis=0) / denominators.sum(axis=0)
489+
490+
# Calculate weights
491+
weights = np.ones_like(jaccards, dtype=float)
479492
if self.inputs.weighting != "none":
480-
weights = 1.0 / np.array(volumes)
493+
volumes = np.sum((refdata + tstdata) > 0, axis=1).reshape((-1, ncomp))
494+
weights = 1.0 / volumes
481495
if self.inputs.weighting == "squared_vol":
482496
weights = weights**2
483497

484498
weights = weights / np.sum(weights)
499+
dices = 2.0 * jaccards / (jaccards + 1.0)
485500

486-
setattr(self, '_jaccard', np.sum(weights * self._jaccards))
487-
setattr(self, '_dice', np.sum(weights * self._dices))
488-
489-
diff = np.zeros(diff_im[0].shape)
490-
491-
for w, ch in zip(weights, diff_im):
492-
ch[msk == 0] = 0
493-
diff += w * ch
494-
495-
nb.save(
496-
nb.Nifti1Image(diff,
497-
nb.load(self.inputs.in_ref[0]).affine,
498-
nb.load(self.inputs.in_ref[0]).header),
499-
self.inputs.out_file)
500-
501+
# Fill-in the results object
502+
self._results['jaccard'] = float(weights.dot(jaccards))
503+
self._results['dice'] = float(weights.dot(dices))
504+
self._results['class_fji'] = [float(v) for v in jaccards]
505+
self._results['class_fdi'] = [float(v) for v in dices]
501506
return runtime
502507

503-
def _list_outputs(self):
504-
outputs = self._outputs().get()
505-
for method in ("dice", "jaccard"):
506-
outputs[method] = getattr(self, '_' + method)
507-
# outputs['volume_difference'] = self._volume
508-
outputs['diff_file'] = os.path.abspath(self.inputs.out_file)
509-
outputs['class_fji'] = np.array(self._jaccards).astype(float).tolist()
510-
outputs['class_fdi'] = self._dices.astype(float).tolist()
511-
return outputs
512-
513508

514509
class ErrorMapInputSpec(BaseInterfaceInputSpec):
515510
in_ref = File(

nipype/algorithms/tests/test_auto_FuzzyOverlap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def test_FuzzyOverlap_inputs():
1010
nohash=True,
1111
usedefault=True,
1212
),
13+
in_mask=dict(),
1314
in_ref=dict(mandatory=True, ),
1415
in_tst=dict(mandatory=True, ),
1516
out_file=dict(usedefault=True, ),
@@ -25,7 +26,6 @@ def test_FuzzyOverlap_outputs():
2526
class_fdi=dict(),
2627
class_fji=dict(),
2728
dice=dict(),
28-
diff_file=dict(),
2929
jaccard=dict(),
3030
)
3131
outputs = FuzzyOverlap.output_spec()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
4+
import numpy as np
5+
import nibabel as nb
6+
from nipype.testing import example_data
7+
from ..metrics import FuzzyOverlap
8+
9+
10+
def test_fuzzy_overlap(tmpdir):
11+
tmpdir.chdir()
12+
13+
# Tests with tissue probability maps
14+
in_mask = example_data('tpms_msk.nii.gz')
15+
tpms = [example_data('tpm_%02d.nii.gz' % i) for i in range(3)]
16+
out = FuzzyOverlap(in_ref=tpms[0], in_tst=tpms[0]).run().outputs
17+
assert out.dice == 1
18+
19+
out = FuzzyOverlap(
20+
in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[0]).run().outputs
21+
assert out.dice == 1
22+
23+
out = FuzzyOverlap(
24+
in_mask=in_mask, in_ref=tpms[0], in_tst=tpms[1]).run().outputs
25+
assert 0 < out.dice < 1
26+
27+
out = FuzzyOverlap(in_ref=tpms, in_tst=tpms).run().outputs
28+
assert out.dice == 1.0
29+
30+
out = FuzzyOverlap(
31+
in_mask=in_mask, in_ref=tpms, in_tst=tpms).run().outputs
32+
assert out.dice == 1.0
33+
34+
# Tests with synthetic 3x3x3 images
35+
data = np.zeros((3, 3, 3), dtype=float)
36+
data[0, 0, 0] = 0.5
37+
data[2, 2, 2] = 0.25
38+
data[1, 1, 1] = 0.3
39+
nb.Nifti1Image(data, np.eye(4)).to_filename('test1.nii.gz')
40+
41+
data = np.zeros((3, 3, 3), dtype=float)
42+
data[0, 0, 0] = 0.6
43+
data[1, 1, 1] = 0.3
44+
nb.Nifti1Image(data, np.eye(4)).to_filename('test2.nii.gz')
45+
46+
out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz').run().outputs
47+
assert np.allclose(out.dice, 0.82051)
48+
49+
# Just considering the mask, the central pixel
50+
# that raised the index now is left aside.
51+
data = np.zeros((3, 3, 3), dtype=int)
52+
data[0, 0, 0] = 1
53+
data[2, 2, 2] = 1
54+
nb.Nifti1Image(data, np.eye(4)).to_filename('mask.nii.gz')
55+
56+
out = FuzzyOverlap(in_ref='test1.nii.gz', in_tst='test2.nii.gz',
57+
in_mask='mask.nii.gz').run().outputs
58+
assert np.allclose(out.dice, 0.74074)

0 commit comments

Comments
 (0)