20
20
21
21
from .. import config , logging
22
22
23
- from ..interfaces .base import (BaseInterface , traits , TraitedSpec , File ,
24
- InputMultiPath , BaseInterfaceInputSpec ,
25
- isdefined )
23
+ from ..interfaces .base import (
24
+ SimpleInterface , BaseInterface , traits , TraitedSpec , File ,
25
+ InputMultiPath , BaseInterfaceInputSpec ,
26
+ isdefined )
26
27
from ..interfaces .nipy .base import NipyBaseInterface
27
- from ..utils import NUMPY_MMAP
28
28
29
29
iflogger = logging .getLogger ('interface' )
30
30
@@ -383,6 +383,7 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
383
383
File (exists = True ),
384
384
mandatory = True ,
385
385
desc = 'Test image. Requires the same dimensions as in_ref.' )
386
+ in_mask = File (exists = True , desc = 'calculate overlap only within mask' )
386
387
weighting = traits .Enum (
387
388
'none' ,
388
389
'volume' ,
@@ -403,10 +404,6 @@ class FuzzyOverlapInputSpec(BaseInterfaceInputSpec):
403
404
class FuzzyOverlapOutputSpec (TraitedSpec ):
404
405
jaccard = traits .Float (desc = 'Fuzzy Jaccard Index (fJI), all the classes' )
405
406
dice = traits .Float (desc = 'Fuzzy Dice Index (fDI), all the classes' )
406
- diff_file = File (
407
- exists = True ,
408
- desc =
409
- 'resulting difference-map of all classes, using the chosen weighting' )
410
407
class_fji = traits .List (
411
408
traits .Float (),
412
409
desc = 'Array containing the fJIs of each computed class' )
@@ -415,7 +412,7 @@ class FuzzyOverlapOutputSpec(TraitedSpec):
415
412
desc = 'Array containing the fDIs of each computed class' )
416
413
417
414
418
- class FuzzyOverlap (BaseInterface ):
415
+ class FuzzyOverlap (SimpleInterface ):
419
416
"""Calculates various overlap measures between two maps, using the fuzzy
420
417
definition proposed in: Crum et al., Generalized Overlap Measures for
421
418
Evaluation and Validation in Medical Image Analysis, IEEE Trans. Med.
@@ -439,77 +436,75 @@ class FuzzyOverlap(BaseInterface):
439
436
output_spec = FuzzyOverlapOutputSpec
440
437
441
438
def _run_interface (self , runtime ):
442
- ncomp = len (self .inputs .in_ref )
443
- assert (ncomp == len (self .inputs .in_tst ))
444
- weights = np .ones (shape = ncomp )
445
-
446
- img_ref = np .array ([
447
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
448
- for fname in self .inputs .in_ref
449
- ])
450
- img_tst = np .array ([
451
- nb .load (fname , mmap = NUMPY_MMAP ).get_data ()
452
- for fname in self .inputs .in_tst
453
- ])
454
-
455
- msk = np .sum (img_ref , axis = 0 )
456
- msk [msk > 0 ] = 1.0
457
- tst_msk = np .sum (img_tst , axis = 0 )
458
- tst_msk [tst_msk > 0 ] = 1.0
459
-
460
- # check that volumes are normalized
461
- # img_ref[:][msk>0] = img_ref[:][msk>0] / (np.sum( img_ref, axis=0 ))[msk>0]
462
- # img_tst[tst_msk>0] = img_tst[tst_msk>0] / np.sum( img_tst, axis=0 )[tst_msk>0]
463
-
464
- self ._jaccards = []
465
- volumes = []
466
-
467
- diff_im = np .zeros (img_ref .shape )
468
-
469
- for ref_comp , tst_comp , diff_comp in zip (img_ref , img_tst , diff_im ):
470
- num = np .minimum (ref_comp , tst_comp )
471
- ddr = np .maximum (ref_comp , tst_comp )
472
- diff_comp [ddr > 0 ] += 1.0 - (num [ddr > 0 ] / ddr [ddr > 0 ])
473
- self ._jaccards .append (np .sum (num ) / np .sum (ddr ))
474
- volumes .append (np .sum (ref_comp ))
475
-
476
- self ._dices = 2.0 * (np .array (self ._jaccards ) /
477
- (np .array (self ._jaccards ) + 1.0 ))
439
+ # Load data
440
+ refdata = nb .concat_images (self .inputs .in_ref ).get_data ()
441
+ tstdata = nb .concat_images (self .inputs .in_tst ).get_data ()
442
+
443
+ # Data must have same shape
444
+ if not refdata .shape == tstdata .shape :
445
+ raise RuntimeError (
446
+ 'Size of "in_tst" %s must match that of "in_ref" %s.' %
447
+ (tstdata .shape , refdata .shape ))
448
+
449
+ ncomp = refdata .shape [- 1 ]
478
450
451
+ # Load mask
452
+ mask = np .ones_like (refdata , dtype = bool )
453
+ if isdefined (self .inputs .in_mask ):
454
+ mask = nb .load (self .inputs .in_mask ).get_data ()
455
+ mask = mask > 0
456
+ mask = np .repeat (mask [..., np .newaxis ], ncomp , - 1 )
457
+ assert mask .shape == refdata .shape
458
+
459
+ # Drop data outside mask
460
+ refdata = refdata [mask ]
461
+ tstdata = tstdata [mask ]
462
+
463
+ if np .any (refdata < 0.0 ):
464
+ iflogger .warning ('Negative values encountered in "in_ref" input, '
465
+ 'taking absolute values.' )
466
+ refdata = np .abs (refdata )
467
+
468
+ if np .any (tstdata < 0.0 ):
469
+ iflogger .warning ('Negative values encountered in "in_tst" input, '
470
+ 'taking absolute values.' )
471
+ tstdata = np .abs (tstdata )
472
+
473
+ if np .any (refdata > 1.0 ):
474
+ iflogger .warning ('Values greater than 1.0 found in "in_ref" input, '
475
+ 'scaling values.' )
476
+ refdata /= refdata .max ()
477
+
478
+ if np .any (tstdata > 1.0 ):
479
+ iflogger .warning ('Values greater than 1.0 found in "in_tst" input, '
480
+ 'scaling values.' )
481
+ tstdata /= tstdata .max ()
482
+
483
+ numerators = np .atleast_2d (
484
+ np .minimum (refdata , tstdata ).reshape ((- 1 , ncomp )))
485
+ denominators = np .atleast_2d (
486
+ np .maximum (refdata , tstdata ).reshape ((- 1 , ncomp )))
487
+
488
+ jaccards = numerators .sum (axis = 0 ) / denominators .sum (axis = 0 )
489
+
490
+ # Calculate weights
491
+ weights = np .ones_like (jaccards , dtype = float )
479
492
if self .inputs .weighting != "none" :
480
- weights = 1.0 / np .array (volumes )
493
+ volumes = np .sum ((refdata + tstdata ) > 0 , axis = 1 ).reshape ((- 1 , ncomp ))
494
+ weights = 1.0 / volumes
481
495
if self .inputs .weighting == "squared_vol" :
482
496
weights = weights ** 2
483
497
484
498
weights = weights / np .sum (weights )
499
+ dices = 2.0 * jaccards / (jaccards + 1.0 )
485
500
486
- setattr (self , '_jaccard' , np .sum (weights * self ._jaccards ))
487
- setattr (self , '_dice' , np .sum (weights * self ._dices ))
488
-
489
- diff = np .zeros (diff_im [0 ].shape )
490
-
491
- for w , ch in zip (weights , diff_im ):
492
- ch [msk == 0 ] = 0
493
- diff += w * ch
494
-
495
- nb .save (
496
- nb .Nifti1Image (diff ,
497
- nb .load (self .inputs .in_ref [0 ]).affine ,
498
- nb .load (self .inputs .in_ref [0 ]).header ),
499
- self .inputs .out_file )
500
-
501
+ # Fill-in the results object
502
+ self ._results ['jaccard' ] = float (weights .dot (jaccards ))
503
+ self ._results ['dice' ] = float (weights .dot (dices ))
504
+ self ._results ['class_fji' ] = [float (v ) for v in jaccards ]
505
+ self ._results ['class_fdi' ] = [float (v ) for v in dices ]
501
506
return runtime
502
507
503
- def _list_outputs (self ):
504
- outputs = self ._outputs ().get ()
505
- for method in ("dice" , "jaccard" ):
506
- outputs [method ] = getattr (self , '_' + method )
507
- # outputs['volume_difference'] = self._volume
508
- outputs ['diff_file' ] = os .path .abspath (self .inputs .out_file )
509
- outputs ['class_fji' ] = np .array (self ._jaccards ).astype (float ).tolist ()
510
- outputs ['class_fdi' ] = self ._dices .astype (float ).tolist ()
511
- return outputs
512
-
513
508
514
509
class ErrorMapInputSpec (BaseInterfaceInputSpec ):
515
510
in_ref = File (
0 commit comments