Skip to content

Commit

Permalink
FIX: '--user-dir' on multi-gpu (facebookresearch#449)
Browse files Browse the repository at this point in the history
Summary:
On a multi-gpu training scenario, the `train.py` script spawns new processes with `torch.multiprocessing.spawn`. Unfortunately those child processes don't inherit the modules imported with `--user-dir`.

This pull request fixes this problem: the custom module import is now explicit in every `main()` function.
Pull Request resolved: facebookresearch#449

Differential Revision: D13676922

Pulled By: myleott

fbshipit-source-id: 520358d66155697885b878a37e7d0484bddbc1c6
  • Loading branch information
davidecaroselli authored and facebook-github-bot committed Jan 16, 2019
1 parent bdec179 commit 7853818
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 11 deletions.
3 changes: 3 additions & 0 deletions eval_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fairseq import options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


class WordStat(object):
Expand Down Expand Up @@ -47,6 +48,8 @@ def __str__(self):
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'

import_user_module(parsed_args)

print(parsed_args)

use_cuda = torch.cuda.is_available() and not parsed_args.cpu
Expand Down
4 changes: 1 addition & 3 deletions fairseq/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def get_parser(desc, default_task='translation'):
usr_parser = argparse.ArgumentParser(add_help=False)
usr_parser.add_argument('--user-dir', default=None)
usr_args, _ = usr_parser.parse_known_args()

if usr_args.user_dir is not None:
import_user_module(usr_args.user_dir)
import_user_module(usr_args)

parser = argparse.ArgumentParser()
# fmt: off
Expand Down
19 changes: 12 additions & 7 deletions fairseq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,15 @@ def nullsafe_min(l):
return max_positions


def import_user_module(module_path):
module_path = os.path.abspath(module_path)
module_parent, module_name = os.path.split(module_path)

sys.path.insert(0, module_parent)
importlib.import_module(module_name)
sys.path.pop(0)
def import_user_module(args):
    """Import the custom module directory given by ``args.user_dir``.

    Called explicitly from every entry point's ``main()`` so that child
    processes spawned via ``torch.multiprocessing.spawn`` (multi-GPU
    training) also load the user's extensions, which are not inherited
    from the parent process.

    Args:
        args: an ``argparse.Namespace`` (or parsed-args object). If it has
            no ``user_dir`` attribute, or the attribute is None, this is a
            no-op. The module is imported at most once per process.
    """
    # getattr with a default avoids a NameError when args lacks user_dir
    # entirely (e.g. parsers that never registered the option).
    module_path = getattr(args, 'user_dir', None)
    if module_path is not None:
        module_path = os.path.abspath(module_path)
        module_parent, module_name = os.path.split(module_path)

        # Guard against double-import when main() is re-entered in the
        # same process (e.g. the spawning parent).
        if module_name not in sys.modules:
            # Temporarily prepend the parent dir so the import resolves;
            # the try/finally guarantees sys.path is restored even when
            # the user module fails to import.
            sys.path.insert(0, module_parent)
            try:
                importlib.import_module(module_name)
            finally:
                sys.path.pop(0)
3 changes: 3 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


def main(args):
Expand All @@ -24,6 +25,8 @@ def main(args):
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'

import_user_module(args)

if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
Expand Down
4 changes: 3 additions & 1 deletion interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator

from fairseq.utils import import_user_module

Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
Expand Down Expand Up @@ -56,6 +56,8 @@ def make_batches(lines, args, task, max_positions):


def main(args):
import_user_module(args)

if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
Expand Down
4 changes: 4 additions & 0 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool

from fairseq.utils import import_user_module


def get_parser():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -66,6 +68,8 @@ def get_parser():


def main(args):
import_user_module(args)

print(args)
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
Expand Down
3 changes: 3 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.utils import import_user_module


def main(args):
import_user_module(args)

if args.max_tokens is None:
args.max_tokens = 6000
print(args)
Expand Down

0 comments on commit 7853818

Please sign in to comment.