Commit 01be083

Centralize hydra init (and support packaged location of configs) (facebookresearch#2784)

Summary:
Configs can either be in `fairseq/config` (once the package is installed) or in the repository-root `config` directory (if using an editable installation). This centralizes the hydra init and supports these two possible config locations.
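
In code the probe is a simple existence check (a minimal sketch of the pattern used in `fairseq/dataclass/utils.py` below; hydra resolves `config_path` relative to the calling module):

    import os

    # installed package: configs ship inside the fairseq package itself
    config_path = os.path.join("..", "config")
    if not os.path.exists(config_path):
        # editable ("pip install -e") checkout: configs live at the repo root
        config_path = os.path.join("..", "..", "config")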

Pull Request resolved: facebookresearch#2784

Reviewed By: alexeib

Differential Revision: D24513586

Pulled By: myleott

fbshipit-source-id: 8e10a88177ebcf809d5d37d448d2b384142febef
Myle Ott authored and facebook-github-bot committed Oct 27, 2020
1 parent beeac0a commit 01be083
Showing 8 changed files with 17 additions and 40 deletions.
4 changes: 4 additions & 0 deletions fairseq/__init__.py
@@ -23,6 +23,10 @@
 sys.modules["fairseq.metrics"] = metrics
 sys.modules["fairseq.progress_bar"] = progress_bar
 
+# initialize hydra
+from fairseq.dataclass.initialize import hydra_init
+hydra_init()
+
 import fairseq.criterions  # noqa
 import fairseq.models  # noqa
 import fairseq.modules  # noqa
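
With this hook, merely importing the package registers the configs, so downstream scripts need no per-script setup. A sketch of the intended usage (not code from this commit):

    import fairseq  # import side effect: hydra_init() registers the configs
    from fairseq.dataclass.utils import convert_namespace_to_omegaconf

    # convert_namespace_to_omegaconf(args) can now be called directly, without
    # the register_hydra_cfg()/initialize() preamble removed from the CLI
    # entry points below.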
4 changes: 2 additions & 2 deletions fairseq/dataclass/initialize.py
@@ -30,8 +30,8 @@ def register_module_dataclass(
     cs.store(name=k, group=group, node=node_, provider="fairseq")
 
 
-def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
-    """cs: config store instance, register common training configs"""
+def hydra_init() -> None:
+    cs = ConfigStore.instance()
 
     for k in FairseqConfig.__dataclass_fields__:
         v = FairseqConfig.__dataclass_fields__[k].default
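
Judging from the loop left visible in this hunk, the centralized hydra_init() presumably amounts to something like the following (a sketch; the top-level cs.store() call and its arguments are an assumption, not shown in the diff):

    from hydra.core.config_store import ConfigStore

    from fairseq.dataclass.configs import FairseqConfig

    def hydra_init() -> None:
        cs = ConfigStore.instance()
        # assumed: register the composite top-level node under the name "config"
        cs.store(name="config", node=FairseqConfig, provider="fairseq")
        # register each top-level FairseqConfig field under its own name
        for k in FairseqConfig.__dataclass_fields__:
            v = FairseqConfig.__dataclass_fields__[k].default
            cs.store(name=k, node=v, provider="fairseq")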
19 changes: 11 additions & 8 deletions fairseq/dataclass/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import ast
+import os
 from argparse import ArgumentError, ArgumentParser, Namespace
 from dataclasses import _MISSING_TYPE, MISSING
 from enum import Enum
@@ -272,19 +273,21 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
 
 
 def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
+    """Convert a flat argparse.Namespace to a structured DictConfig."""
 
     # Here we are using field values provided in args to override counterparts inside config object
     overrides, deletes = override_module_args(args)
 
-    cfg_name = "config"
-    cfg_path = f"../../{cfg_name}"
+    # configs will be in fairseq/config after installation
+    config_path = os.path.join("..", "config")
+    if not os.path.exists(config_path):
+        # in case of "--editable" installs we need to go one dir up
+        config_path = os.path.join("..", "..", "config")
 
-    if not GlobalHydra().is_initialized():
-        initialize(config_path=cfg_path)
-
-    composed_cfg = compose(cfg_name, overrides=overrides, strict=False)
-    for k in deletes:
-        composed_cfg[k] = None
+    with initialize(config_path=config_path, strict=True):
+        composed_cfg = compose("config", overrides=overrides, strict=False)
+        for k in deletes:
+            composed_cfg[k] = None
 
     cfg = OmegaConf.create(
         OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
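
For reference, the context-manager pattern adopted here can be exercised standalone with hydra 1.0's experimental API (illustrative only; the override string assumes fairseq's common.seed field):

    from hydra.experimental import compose, initialize

    # Using initialize() as a context manager scopes hydra's global state to
    # this block rather than leaving the process-wide singleton initialized.
    with initialize(config_path="../config"):
        cfg = compose("config", overrides=["common.seed=1"])
        print(cfg.common.seed)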
6 changes: 0 additions & 6 deletions fairseq_cli/eval_lm.py
@@ -16,13 +16,10 @@
 import torch
 from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
 from fairseq.data import LMContextWindowDataset
-from fairseq.dataclass.initialize import register_hydra_cfg
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.logging import progress_bar
 from fairseq.logging.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer
-from hydra.core.config_store import ConfigStore
-from hydra.experimental import initialize
 from omegaconf import DictConfig
 
 
@@ -288,7 +285,4 @@ def cli_main():
 
 
 if __name__ == "__main__":
-    cs = ConfigStore.instance()
-    register_hydra_cfg(cs)
-    initialize(config_path="../config", strict=True)
     cli_main()
6 changes: 0 additions & 6 deletions fairseq_cli/generate.py
@@ -19,12 +19,9 @@
 import torch
 from fairseq import checkpoint_utils, options, scoring, tasks, utils
 from fairseq.data import encoders
-from fairseq.dataclass.initialize import register_hydra_cfg
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.logging import progress_bar
 from fairseq.logging.meters import StopwatchMeter, TimeMeter
-from hydra.core.config_store import ConfigStore
-from hydra.experimental import initialize
 from omegaconf import DictConfig
 
 
@@ -393,7 +390,4 @@ def cli_main():
 
 
 if __name__ == "__main__":
-    cs = ConfigStore.instance()
-    register_hydra_cfg(cs)
-    initialize(config_path="../config", strict=True)
     cli_main()
6 changes: 0 additions & 6 deletions fairseq_cli/interactive.py
@@ -22,12 +22,9 @@
 from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
 from fairseq.data import encoders
 from fairseq.dataclass.configs import FairseqConfig
-from fairseq.dataclass.initialize import register_hydra_cfg
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
 from fairseq_cli.generate import get_symbols_to_strip_from_output
-from hydra.core.config_store import ConfigStore
-from hydra.experimental import initialize
 
 
 logging.basicConfig(
@@ -322,7 +319,4 @@ def cli_main():
 
 
 if __name__ == "__main__":
-    cs = ConfigStore.instance()
-    register_hydra_cfg(cs)
-    initialize(config_path="../config", strict=True)
     cli_main()
6 changes: 0 additions & 6 deletions fairseq_cli/train.py
@@ -16,7 +16,6 @@
 
 import numpy as np
 import torch
-from hydra.core.config_store import ConfigStore
 
 from fairseq import (
     checkpoint_utils,
@@ -31,8 +30,6 @@
 from fairseq.logging import meters, metrics, progress_bar
 from fairseq.model_parallel.megatron_trainer import MegatronTrainer
 from omegaconf import DictConfig
-from hydra.experimental import initialize
-from fairseq.dataclass.initialize import register_hydra_cfg
 from fairseq.trainer import Trainer
 
 
@@ -353,7 +350,4 @@ def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]]
 
 
 if __name__ == '__main__':
-    cs = ConfigStore.instance()
-    register_hydra_cfg(cs)
-    initialize(config_path="../config", strict=True)
     cli_main()
6 changes: 0 additions & 6 deletions fairseq_cli/validate.py
@@ -13,11 +13,8 @@
 
 import torch
 from fairseq import checkpoint_utils, distributed_utils, options, utils
-from fairseq.dataclass.initialize import register_hydra_cfg
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.logging import metrics, progress_bar
-from hydra.core.config_store import ConfigStore
-from hydra.experimental import initialize
 from omegaconf import DictConfig
 
 
@@ -140,7 +137,4 @@ def cli_main():
 
 
 if __name__ == "__main__":
-    cs = ConfigStore.instance()
-    register_hydra_cfg(cs)
-    initialize(config_path="../config", strict=True)
     cli_main()
