Skip to content

Commit

Permalink
updated outputs for template building
Browse files Browse the repository at this point in the history
  • Loading branch information
astewartau committed Jun 11, 2021
1 parent cb3ba7b commit 7cbba38
Show file tree
Hide file tree
Showing 3 changed files with 426 additions and 32 deletions.
63 changes: 31 additions & 32 deletions run_4_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import glob
import argparse
import psutil
import re

import nipype.interfaces.utility as util
import nipype.interfaces.ants as ants
import nipype.interfaces.io as io
import nipype.pipeline.engine as pe

from niflow.nipype1.workflows.smri.ants import ANTSTemplateBuildSingleIterationWF
from scripts.antsBuildTemplate import ANTSTemplateBuildSingleIterationWF

def init_workflow(magnitude_images, qsm_images):

Expand All @@ -21,13 +22,14 @@ def init_workflow(magnitude_images, qsm_images):
# datasource
datasource = pe.Node(
interface=util.IdentityInterface(
fields=['imageList', 'passiveImagesDictionariesList']
fields=['magnitude_images', 'qsm_images', 'qsm_dict']
),
run_without_submitting=True,
name='InputImages'
)
datasource.inputs.imageList = magnitude_images
datasource.inputs.passiveImagesDictionariesList = qsm_images
datasource.inputs.magnitude_images = magnitude_images
datasource.inputs.qsm_images = qsm_images
datasource.inputs.qsm_dict = [{'QSM' : x} for x in qsm_images]
datasource.inputs.sort_filelist = True

# initial average
Expand All @@ -38,15 +40,15 @@ def init_workflow(magnitude_images, qsm_images):
initAvg.inputs.dimension = 3
initAvg.inputs.normalize = True
wf.connect([
(datasource, initAvg, [('imageList', 'images')])
(datasource, initAvg, [('magnitude_images', 'images')])
])

# first iteration
buildTemplateIteration1 = ANTSTemplateBuildSingleIterationWF('iteration01')
wf.connect([
(initAvg, buildTemplateIteration1, [('output_average_image', 'inputspec.fixed_image')]),
(datasource, buildTemplateIteration1, [('imageList', 'inputspec.images')]),
(datasource, buildTemplateIteration1, [('passiveImagesDictionariesList', 'inputspec.ListOfPassiveImagesDictionaries')]),
(datasource, buildTemplateIteration1, [('magnitude_images', 'inputspec.images')]),
(datasource, buildTemplateIteration1, [('qsm_dict', 'inputspec.ListOfPassiveImagesDictionaries')]),
])
BeginANTS1 = buildTemplateIteration1.get_node("BeginANTS")
BeginANTS1.plugin_args = {
Expand All @@ -58,8 +60,8 @@ def init_workflow(magnitude_images, qsm_images):
buildTemplateIteration2 = ANTSTemplateBuildSingleIterationWF('iteration02')
wf.connect([
(buildTemplateIteration1, buildTemplateIteration2, [('outputspec.template', 'inputspec.fixed_image')]),
(datasource, buildTemplateIteration2, [('imageList', 'inputspec.images')]),
(datasource, buildTemplateIteration2, [('passiveImagesDictionariesList', 'inputspec.ListOfPassiveImagesDictionaries')])
(datasource, buildTemplateIteration2, [('magnitude_images', 'inputspec.images')]),
(datasource, buildTemplateIteration2, [('qsm_dict', 'inputspec.ListOfPassiveImagesDictionaries')])
])
BeginANTS2 = buildTemplateIteration2.get_node("BeginANTS")
BeginANTS2.plugin_args = {
Expand All @@ -69,14 +71,16 @@ def init_workflow(magnitude_images, qsm_images):

# datasink
datasink = pe.Node(
io.DataSink(),
io.DataSink(base_directory=args.out_dir),
name="datasink"
)
datasink.inputs.base_directory = os.path.join('out/test', "results")
wf.connect([
(buildTemplateIteration2, datasink, [('outputspec.template', 'PrimaryTemplate')]),
(buildTemplateIteration2, datasink, [('outputspec.passive_deformed_templates', 'PassiveTemplate')]),
(initAvg, datasink, [('output_average_image', 'PreRegisterAverage')]),
(initAvg, datasink, [('output_average_image', 'initial_average')]),
(buildTemplateIteration2, datasink, [('outputspec.template', 'magnitude_template')]),
(buildTemplateIteration2, datasink, [('outputspec.passive_deformed_templates', 'qsm_template')]),
(buildTemplateIteration2, datasink, [('outputspec.flattened_transforms', 'transforms')]),
(buildTemplateIteration2, datasink, [('outputspec.wimtPassivedeformed', 'qsms_transformed')])
])

return wf
Expand Down Expand Up @@ -112,6 +116,12 @@ def init_workflow(magnitude_images, qsm_images):
help='NiPype working directory; defaults to \'work\' within \'out_dir\'.'
)

parser.add_argument(
'--qsm_pattern',
default=os.path.join('qsm_final', '*', '*.nii'),
help='Pattern used to match QSM images in qsm_dir'
)

parser.add_argument(
'--subject_pattern',
default='sub*',
Expand All @@ -131,20 +141,6 @@ def init_workflow(magnitude_images, qsm_images):
'The {subject}, {session} and {run} placeholders must be present.'
)

parser.add_argument(
'--subjects',
default=None,
nargs='*',
help='List of subject folders to process; by default all subjects are processed.'
)

parser.add_argument(
'--sessions',
default=None,
nargs='*',
help='List of session folders to process; by default all sessions are processed.'
)

parser.add_argument(
'--pbs',
default=None,
Expand Down Expand Up @@ -175,6 +171,10 @@ def init_workflow(magnitude_images, qsm_images):
args.work_dir = os.path.abspath(args.work_dir)
args.out_dir = os.path.abspath(args.out_dir)

# environment variables for multi-threading
os.environ["OMP_NUM_THREADS"] = "6"
os.environ["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "6"

# set number of concurrent processes to run depending on
# available CPUs and RAM (max 1 per 3 GB of available RAM)
n_cpus = int(os.environ["NCPUS"]) if "NCPUS" in os.environ else int(os.cpu_count())
Expand All @@ -184,18 +184,17 @@ def init_workflow(magnitude_images, qsm_images):

# find input images
magnitude_pattern = os.path.join(args.bids_dir, args.magnitude_pattern.format(subject=args.subject_pattern, session=args.session_pattern, run='*'))
qsm_pattern = os.path.join(args.qsm_dir, "qsm_final", "*", "*.nii*")
qsm_pattern = os.path.join(args.qsm_dir, args.qsm_pattern)
magnitude_images = sorted(glob.glob(magnitude_pattern))
magnitude_images = [x for x in magnitude_images if 'E01' in x or not re.findall('E[0-9]{2}', x)]
qsm_images = sorted(glob.glob(qsm_pattern))


if len(magnitude_images) != len(qsm_images):
print(f"QSMxT: Error: Number of QSM images ({len(qsm_images)}) and magnitude images ({len(magnitude_images)}) do not match.")
print(f"Final QSM pattern: {qsm_pattern}")
print(f"Final magintude pattern: {magnitude_pattern}")
exit()

# convert qsm_images to dictionary
qsm_images = [{'QSM' : x} for x in qsm_images]

wf = init_workflow(magnitude_images, qsm_images)

if args.qsub_account_string:
Expand Down
File renamed without changes.
Loading

0 comments on commit 7cbba38

Please sign in to comment.