Skip to content

Commit

Permalink
Merge pull request #210 from desihub/no-mpi-split
Browse files Browse the repository at this point in the history
completely eliminate the use of subcommunicators
  • Loading branch information
moustakas authored Feb 1, 2025
2 parents b5cfa4e + 7feda1e commit ce005d7
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 126 deletions.
157 changes: 31 additions & 126 deletions bin/mpi-fastspecfit
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,6 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=F
templateversion=None, fphotodir=None, fphotofile=None):
"""Main wrapper to run fastspec, fastphot, or fastqa.
Top-level MPI parallelization is over (e.g., healpix) files, but there is
another level of parallelization which makes use of subcommunicators.
For example, calling `mpi-fastspecfit` with 8 MPI tasks and --mp=4 will
result in two (8/4) healpix files being processed simultaneously
(specifically, by ranks 0 and 4) and then a further level of
parallelization over the objects in each of those files, specifically by
subranks (0, 1, 2, 3) and (4, 5, 6, 7), respectively.
"""
import sys
from desispec.parallel import stdouterr_redirected, weighted_partition
Expand All @@ -102,7 +93,6 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=F
size = comm.size
else:
rank, size = 0, 1
subcomm = None

if rank == 0:
t0 = time.time()
Expand Down Expand Up @@ -130,81 +120,31 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=F
return

if comm:
# Split the MPI.COMM_WORLD communicator into subcommunicators (of size
# args.mp) so we can MPI-parallelize over objects.
allranks = np.arange(comm.size)
if args.purempi:
colors = np.arange(comm.size) // args.mp
color = rank // args.mp
else:
colors = np.arange(comm.size)
color = rank

subcomm = comm.Split(color=color, key=rank)

if rank == 0:
if args.purempi:
subranks0 = allranks[::args.mp] # rank=0 in each subcomm
log.info(f'Rank {rank}: dividing filelist into {len(subranks0):,d} sub-communicator(s) ' + \
f'(size={comm.size:,d}, mp={args.mp}).')
else:
subranks0 = allranks
log.info(f'Rank {rank}: dividing filelist across {comm.size:,d} ranks.')
else:
subranks0 = None

subranks0 = comm.bcast(subranks0, root=0)

# Send the filelists and number of targets to each subrank0.
if rank == 0:
groups = weighted_partition(all_ntargets, len(subranks0))
for irank in range(1, len(subranks0)):
log.debug(f'Rank {rank} sending work to rank {subranks0[irank]}')
comm.send(all_redrockfiles[groups[irank]], dest=subranks0[irank], tag=1)
comm.send(all_outfiles[groups[irank]], dest=subranks0[irank], tag=2)
comm.send(all_ntargets[groups[irank]], dest=subranks0[irank], tag=3)
redrockfiles = all_redrockfiles[groups[rank]]
outfiles = all_outfiles[groups[rank]]
ntargets = all_ntargets[groups[rank]]
groups = weighted_partition(all_ntargets, size)
for irank in range(1, size):
log.debug(f'Rank {rank} sending work to rank {irank}')
comm.send(all_redrockfiles[groups[irank]], dest=irank, tag=1)
comm.send(all_outfiles[groups[irank]], dest=irank, tag=2)
comm.send(all_ntargets[groups[irank]], dest=irank, tag=3)
# rank 0 gets work, too
redrockfiles = all_redrockfiles[groups[0]]
outfiles = all_outfiles[groups[0]]
ntargets = all_ntargets[groups[0]]
else:
if rank in subranks0:
log.debug(f'Rank {rank}: received work from rank 0')
redrockfiles = comm.recv(source=0, tag=1)
outfiles = comm.recv(source=0, tag=2)
ntargets = comm.recv(source=0, tag=3)

# Each subrank0 sends work to the subranks it controls.
if subcomm.rank == 0:
subranks = allranks[np.isin(colors, color)]
# process from smallest to largest
srt = np.argsort(ntargets)#[::-1]
redrockfiles = redrockfiles[srt]
outfiles = outfiles[srt]
ntargets = ntargets[srt]
for irank in range(1, subcomm.size):
log.debug(f'Subrank 0 (rank {rank}) sending work to subrank {irank} (rank {subranks[irank]})')
subcomm.send(redrockfiles, dest=irank, tag=4)
subcomm.send(outfiles, dest=irank, tag=5)
subcomm.send(ntargets, dest=irank, tag=6)
else:
redrockfiles = subcomm.recv(source=0, tag=4)
outfiles = subcomm.recv(source=0, tag=5)
ntargets = subcomm.recv(source=0, tag=6)
log.debug(f'Rank {rank}: received work from rank 0')
redrockfiles = comm.recv(source=0, tag=1)
outfiles = comm.recv(source=0, tag=2)
ntargets = comm.recv(source=0, tag=3)
else:
redrockfiles = all_redrockfiles
outfiles = all_outfiles
ntargets = all_ntargets
#print(f'Rank={comm.rank}, subrank={subcomm.rank}, redrockfiles={redrockfiles}, ntargets={ntargets}')


# loop on each file
for redrockfile, outfile, ntarget in zip(redrockfiles, outfiles, ntargets):
if subcomm:
if subcomm.rank == 0:
if args.purempi:
log.debug(f'Rank {rank} (subrank {subcomm.rank}) started ' + \
f'at {time.asctime()}')
else:
log.debug(f'Rank {rank} started at {time.asctime()}')
elif rank == 0:
if rank == 0:
log.debug(f'Rank {rank} started at {time.asctime()}')

if args.makeqa:
Expand All @@ -218,67 +158,35 @@ def run_fastspecfit(args, comm=None, fastphot=False, specprod_dir=None, makeqa=F
cmd, cmdargs, logfile = build_cmdargs(args, redrockfile, outfile, sample=sample,
fastphot=fastphot, input_redshifts=input_redshifts)

if subcomm:
if subcomm.rank == 0:
if args.purempi:
log.info(f'Rank {rank} (nsubrank={subcomm.size}): ' + \
f'ntargets={ntarget}: {cmd} {cmdargs}')
else:
log.info(f'Rank {rank} ntargets={ntarget}: {cmd} {cmdargs}')
elif rank == 0:
if rank == 0:
log.info(f'Rank {rank}: ntargets={ntarget}: {cmd} {cmdargs}')

if args.dry_run:
continue

try:
if subcomm:
if subcomm.rank == 0:
t1 = time.time()
outdir = os.path.dirname(logfile)
if not os.path.isdir(outdir):
os.makedirs(outdir, exist_ok=True)
elif rank == 0:
t1 = time.time()
outdir = os.path.dirname(logfile)
if not os.path.isdir(outdir):
os.makedirs(outdir, exist_ok=True)
t1 = time.time()
outdir = os.path.dirname(logfile)
if not os.path.isdir(outdir):
os.makedirs(outdir, exist_ok=True)

if args.nolog:
if args.purempi:
err = fast(args=cmdargs.split(), comm=subcomm)
else:
err = fast(args=cmdargs.split(), comm=None)
err = fast(args=cmdargs.split(), comm=None)
else:
with stdouterr_redirected(to=logfile, overwrite=args.overwrite, comm=subcomm):
if args.purempi:
err = fast(args=cmdargs.split(), comm=subcomm)
else:
err = fast(args=cmdargs.split(), comm=None)

if subcomm:
if subcomm.rank == 0:
log.info(f'Rank {rank} done in {time.time() - t1:.2f} sec')
if err != 0:
if not os.path.exists(outfile):
log.warning(f'Rank {rank} missing {outfile}')
raise IOError
elif rank == 0:
log.info(f'Rank {rank} done in {time.time() - t1:.2f} sec')
if err != 0:
if not os.path.exists(outfile):
log.warning(f'Rank {rank} missing {outfile}')
raise IOError
with stdouterr_redirected(to=logfile, overwrite=args.overwrite, comm=None):
err = fast(args=cmdargs.split(), comm=None)

log.info(f'Rank {rank} done in {(time.time() - t1)/60.:.2f} min')
if err != 0:
if not os.path.exists(outfile):
log.warning(f'Rank {rank} missing {outfile}')
raise IOError
except:
log.warning(f'Rank {rank} raised an exception')
import traceback
traceback.print_exc()

if subcomm:
if subcomm.rank == 0:
log.debug(f'Rank {rank} is done')
elif rank == 0:
if rank == 0:
log.debug(f'Rank {rank} is done')

if comm:
Expand Down Expand Up @@ -348,7 +256,6 @@ def main():
parser.add_argument('--plan', action='store_true', help='Plan how many nodes to use and how to distribute the targets.')
parser.add_argument('--profile', action='store_true', help='Write out profiling / timing files..')
parser.add_argument('--nompi', action='store_true', help='Do not use MPI parallelism.')
parser.add_argument('--purempi', action='store_true', help='Use only MPI parallelism; no multiprocessing.')
parser.add_argument('--nolog', action='store_true', help='Do not write to the log file.')
parser.add_argument('--dry-run', action='store_true', help='Generate but do not run commands.')

Expand All @@ -373,8 +280,6 @@ def main():

if comm:
rank = comm.rank
if rank == 0 and args.purempi and comm.size > 1 and args.mp > 1 and comm.size < args.mp:
log.warning(f'Number of MPI tasks {comm.size} should be >{args.mp} for MPI parallelism.')
else:
rank = 0
# https://docs.nersc.gov/development/languages/python/parallel-python/#use-the-spawn-start-method
Expand Down
7 changes: 7 additions & 0 deletions doc/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ Change Log

*

3.1.3 (2025-02-01)
------------------

* Eliminate the use of subcommunicators in ``mpi-fastspecfit`` [`PR #210`_].

.. _`PR #210`: https://github.com/desihub/fastspecfit/pull/210

3.1.2 (2025-01-08)
------------------

Expand Down

0 comments on commit ce005d7

Please sign in to comment.