diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index e59cac085c..807ad07db5 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -81,7 +81,7 @@ def build_model(cfg): """Build model Args: - cfg (AttrDict): Arch config. + cfg (DictConfig): Arch config. Returns: nn.Layer: Model. diff --git a/ppsci/constraint/__init__.py b/ppsci/constraint/__init__.py index 6cbe1a42b0..9179439436 100644 --- a/ppsci/constraint/__init__.py +++ b/ppsci/constraint/__init__.py @@ -42,7 +42,7 @@ def build_constraint(cfg, equation_dict, geom_dict): """Build constraint(s). Args: - cfg (List[AttrDict]): Constraint config list. + cfg (List[DictConfig]): Constraint config list. equation_dict (Dct[str, Equation]): Equation(s) in dict. geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. diff --git a/ppsci/data/dataset/__init__.py b/ppsci/data/dataset/__init__.py index c0eebe860e..960e8a66b9 100644 --- a/ppsci/data/dataset/__init__.py +++ b/ppsci/data/dataset/__init__.py @@ -78,7 +78,7 @@ def build_dataset(cfg) -> "io.Dataset": """Build dataset Args: - cfg (List[AttrDict]): dataset config list. + cfg (List[DictConfig]): dataset config list. Returns: Dict[str, io.Dataset]: dataset. diff --git a/ppsci/equation/__init__.py b/ppsci/equation/__init__.py index 77a9b20860..2b97d378b7 100644 --- a/ppsci/equation/__init__.py +++ b/ppsci/equation/__init__.py @@ -54,7 +54,7 @@ def build_equation(cfg): """Build equation(s) Args: - cfg (List[AttrDict]): Equation(s) config list. + cfg (List[DictConfig]): Equation(s) config list. Returns: Dict[str, Equation]: Equation(s) in dict. diff --git a/ppsci/geometry/__init__.py b/ppsci/geometry/__init__.py index 4f1ff0b122..768ed0581d 100644 --- a/ppsci/geometry/__init__.py +++ b/ppsci/geometry/__init__.py @@ -54,7 +54,7 @@ def build_geometry(cfg): """Build geometry(ies) Args: - cfg (List[AttrDict]): Geometry config list. + cfg (List[DictConfig]): Geometry config list. Returns: Dict[str, Geometry]: Geometry(ies) in dict. diff --git a/ppsci/loss/__init__.py b/ppsci/loss/__init__.py index 0035a4193f..8bb9496f68 100644 --- a/ppsci/loss/__init__.py +++ b/ppsci/loss/__init__.py @@ -53,7 +53,7 @@ def build_loss(cfg): """Build loss. Args: - cfg (AttrDict): Loss config. + cfg (DictConfig): Loss config. Returns: Loss: Callable loss object. """ diff --git a/ppsci/loss/mtl/__init__.py b/ppsci/loss/mtl/__init__.py index 35f3b73d90..358efb3609 100644 --- a/ppsci/loss/mtl/__init__.py +++ b/ppsci/loss/mtl/__init__.py @@ -35,7 +35,7 @@ def build_mtl_aggregator(cfg): """Build loss aggregator with multi-task learning method. Args: - cfg (AttrDict): Aggregator config. + cfg (DictConfig): Aggregator config. Returns: Loss: Callable loss aggregator object. """ diff --git a/ppsci/metric/__init__.py b/ppsci/metric/__init__.py index 5390db4c4e..6059b22116 100644 --- a/ppsci/metric/__init__.py +++ b/ppsci/metric/__init__.py @@ -43,7 +43,7 @@ def build_metric(cfg): """Build metric. Args: - cfg (List[AttrDict]): List of metric config. + cfg (List[DictConfig]): List of metric config. Returns: Dict[str, Metric]: Dict of callable metric object. diff --git a/ppsci/optimizer/__init__.py b/ppsci/optimizer/__init__.py index c973b489fb..7dcf33b40b 100644 --- a/ppsci/optimizer/__init__.py +++ b/ppsci/optimizer/__init__.py @@ -39,7 +39,7 @@ def build_lr_scheduler(cfg, epochs, iters_per_epoch): """Build learning rate scheduler. Args: - cfg (AttrDict): Learning rate scheduler config. + cfg (DictConfig): Learning rate scheduler config. epochs (int): Total epochs. 
iters_per_epoch (int): Number of iterations of one epoch. @@ -57,7 +57,7 @@ def build_optimizer(cfg, model_list, epochs, iters_per_epoch): """Build optimizer and learning rate scheduler Args: - cfg (AttrDict): Learning rate scheduler config. + cfg (DictConfig): Learning rate scheduler config. model_list (Tuple[nn.Layer, ...]): Tuple of model(s). epochs (int): Total epochs. iters_per_epoch (int): Number of iterations of one epoch. diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index f7a00aa8fc..cde418ea57 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -158,12 +158,18 @@ def __init__( cfg: Optional[DictConfig] = None, ): self.cfg = cfg + if isinstance(cfg, DictConfig): + # (Recommended)Params can be passed within cfg + # rather than passed to 'Solver.__init__' one-by-one. + self._parse_params_from_cfg(cfg) + # set model self.model = model # set constraint self.constraint = constraint # set output directory - self.output_dir = output_dir + if not cfg: + self.output_dir = output_dir # set optimizer self.optimizer = optimizer @@ -192,19 +198,20 @@ def __init__( ) # set training hyper-parameter - self.epochs = epochs - self.iters_per_epoch = iters_per_epoch - # set update_freq for gradient accumulation - self.update_freq = update_freq - # set checkpoint saving frequency - self.save_freq = save_freq - # set logging frequency - self.log_freq = log_freq - - # set evaluation hyper-parameter - self.eval_during_train = eval_during_train - self.start_eval_epoch = start_eval_epoch - self.eval_freq = eval_freq + if not cfg: + self.epochs = epochs + self.iters_per_epoch = iters_per_epoch + # set update_freq for gradient accumulation + self.update_freq = update_freq + # set checkpoint saving frequency + self.save_freq = save_freq + # set logging frequency + self.log_freq = log_freq + + # set evaluation hyper-parameter + self.eval_during_train = eval_during_train + self.start_eval_epoch = start_eval_epoch + self.eval_freq = eval_freq # initialize training log(training loss, time cost, etc.) 
recorder during one epoch self.train_output_info: Dict[str, misc.AverageMeter] = {} @@ -221,21 +228,17 @@ def __init__( "reader_cost": misc.AverageMeter("reader_cost", ".5f", postfix="s"), } - # fix seed for reproducibility - self.seed = seed - # set running device - if device != "cpu" and paddle.device.get_device() == "cpu": + if not cfg: + self.device = device + if self.device != "cpu" and paddle.device.get_device() == "cpu": logger.warning(f"Set device({device}) to 'cpu' for only cpu available.") - device = "cpu" - self.device = paddle.set_device(device) + self.device = "cpu" + self.device = paddle.set_device(self.device) # set equations for physics-driven or data-physics hybrid driven task, such as PINN self.equation = equation - # set geometry for generating data - self.geom = {} if geom is None else geom - # set validator self.validator = validator @@ -243,24 +246,27 @@ def __init__( self.visualizer = visualizer # set automatic mixed precision(AMP) configuration - self.use_amp = use_amp - self.amp_level = amp_level + if not cfg: + self.use_amp = use_amp + self.amp_level = amp_level self.scaler = amp.GradScaler(True) if self.use_amp else None # whether calculate metrics by each batch during evaluation, mainly for memory efficiency - self.compute_metric_by_batch = compute_metric_by_batch + if not cfg: + self.compute_metric_by_batch = compute_metric_by_batch if validator is not None: for metric in itertools.chain( *[_v.metric.values() for _v in self.validator.values()] ): - if metric.keep_batch ^ compute_metric_by_batch: + if metric.keep_batch ^ self.compute_metric_by_batch: raise ValueError( f"{misc.typename(metric)}.keep_batch should be " - f"{compute_metric_by_batch} when compute_metric_by_batch=" - f"{compute_metric_by_batch}." + f"{self.compute_metric_by_batch} when compute_metric_by_batch=" + f"{self.compute_metric_by_batch}." 
) # whether set `stop_gradient=True` for every Tensor if no differentiation involved during evaluation - self.eval_with_no_grad = eval_with_no_grad + if not cfg: + self.eval_with_no_grad = eval_with_no_grad self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -278,19 +284,20 @@ def __init__( # set moving average model(optional) self.ema_model = None if self.cfg and any(key in self.cfg.TRAIN for key in ["ema", "swa"]): - if "ema" in self.cfg.TRAIN: - self.avg_freq = self.cfg.TRAIN.ema.avg_freq + if "ema" in self.cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False): self.ema_model = ema.ExponentialMovingAverage( self.model, self.cfg.TRAIN.ema.decay ) - elif "swa" in self.cfg.TRAIN: - self.avg_freq = self.cfg.TRAIN.swa.avg_freq + elif "swa" in self.cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False): self.ema_model = ema.StochasticWeightAverage(self.model) # load pretrained model, usually used for transfer learning - self.pretrained_model_path = pretrained_model_path - if pretrained_model_path is not None: - save_load.load_pretrain(self.model, pretrained_model_path, self.equation) + if not cfg: + self.pretrained_model_path = pretrained_model_path + if self.pretrained_model_path is not None: + save_load.load_pretrain( + self.model, self.pretrained_model_path, self.equation + ) # initialize an dict for tracking best metric during training self.best_metric = { @@ -298,14 +305,16 @@ def __init__( "epoch": 0, } # load model checkpoint, usually used for resume training - if checkpoint_path is not None: - if pretrained_model_path is not None: + if not cfg: + self.checkpoint_path = checkpoint_path + if self.checkpoint_path is not None: + if self.pretrained_model_path is not None: logger.warning( "Detected 'pretrained_model_path' is given, weights in which might be" "overridden by weights loaded from given 'checkpoint_path'." ) loaded_metric = save_load.load_checkpoint( - checkpoint_path, + self.checkpoint_path, self.model, self.optimizer, self.scaler, @@ -366,7 +375,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set VisualDL tool self.vdl_writer = None - if use_vdl: + if not cfg: + self.use_vdl = use_vdl + if self.use_vdl: with misc.RankZeroOnly(self.rank) as is_master: if is_master: self.vdl_writer = vdl.LogWriter(osp.join(output_dir, "vdl")) @@ -377,7 +388,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set WandB tool self.wandb_writer = None - if use_wandb: + if not cfg: + self.use_wandb = use_wandb + if self.use_wandb: try: import wandb except ModuleNotFoundError: @@ -390,7 +403,9 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel: # set TensorBoardX tool self.tbd_writer = None - if use_tbd: + if not cfg: + self.use_tbd = use_tbd + if self.use_tbd: try: import tensorboardX except ModuleNotFoundError: @@ -984,3 +999,43 @@ def plot_loss_history( smooth_step=smooth_step, use_semilogy=use_semilogy, ) + + def _parse_params_from_cfg(self, cfg: DictConfig): + """ + Parse hyper-parameters from DictConfig. 
+ """ + self.output_dir = cfg.output_dir + self.log_freq = cfg.log_freq + self.use_tbd = cfg.use_tbd + self.use_vdl = cfg.use_vdl + self.wandb_config = cfg.wandb_config + self.use_wandb = cfg.use_wandb + self.device = cfg.device + self.to_static = cfg.to_static + + self.use_amp = cfg.use_amp + self.amp_level = cfg.amp_level + + self.epochs = cfg.TRAIN.epochs + self.iters_per_epoch = cfg.TRAIN.iters_per_epoch + self.update_freq = cfg.TRAIN.update_freq + self.save_freq = cfg.TRAIN.save_freq + self.eval_during_train = cfg.TRAIN.eval_during_train + self.start_eval_epoch = cfg.TRAIN.start_eval_epoch + self.eval_freq = cfg.TRAIN.eval_freq + self.checkpoint_path = cfg.TRAIN.checkpoint_path + + if "ema" in cfg.TRAIN and cfg.TRAIN.ema.get("use_ema", False): + self.avg_freq = cfg.TRAIN.ema.avg_freq + elif "swa" in cfg.TRAIN and cfg.TRAIN.swa.get("use_swa", False): + self.avg_freq = cfg.TRAIN.swa.avg_freq + + self.compute_metric_by_batch = cfg.EVAL.compute_metric_by_batch + self.eval_with_no_grad = cfg.EVAL.eval_with_no_grad + + if cfg.mode == "train": + self.pretrained_model_path = cfg.TRAIN.pretrained_model_path + elif cfg.mode == "eval": + self.pretrained_model_path = cfg.EVAL.pretrained_model_path + elif cfg.mode in ["export", "infer"]: + self.pretrained_model_path = cfg.INFER.pretrained_model_path diff --git a/ppsci/utils/__init__.py b/ppsci/utils/__init__.py index 5b076fb3bb..f397f090ac 100644 --- a/ppsci/utils/__init__.py +++ b/ppsci/utils/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# NOTE: Put config module import at the top level for register default config(s) in +# ConfigStore at the begining of ppsci +from ppsci.utils import config # isort:skip # noqa: F401 from ppsci.utils import ema from ppsci.utils import initializer from ppsci.utils import logger @@ -22,7 +25,6 @@ from ppsci.utils.checker import dynamic_import_to_globals from ppsci.utils.checker import run_check from ppsci.utils.checker import run_check_mesh -from ppsci.utils.config import AttrDict from ppsci.utils.expression import ExpressionSolver from ppsci.utils.misc import AverageMeter from ppsci.utils.misc import set_random_seed @@ -39,7 +41,6 @@ from ppsci.utils.writer import save_tecplot_file __all__ = [ - "AttrDict", "AverageMeter", "ExpressionSolver", "initializer", diff --git a/ppsci/utils/callbacks.py b/ppsci/utils/callbacks.py index e55a29130f..bcfbbd46bd 100644 --- a/ppsci/utils/callbacks.py +++ b/ppsci/utils/callbacks.py @@ -31,9 +31,10 @@ class InitCallback(Callback): """Callback class for: - 1. Parse config dict from given yaml file and check its validity, complete missing items by its' default values. + 1. Parse config dict from given yaml file and check its validity. 2. Fixing random seed to 'config.seed'. 3. Initialize logger while creating output directory(if not exist). + 4. Enable prim mode if specified. NOTE: This callback is mainly for reducing unnecessary duplicate code in each examples code when runing with hydra. 
@@ -60,8 +61,6 @@ class InitCallback(Callback): """ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: - # check given cfg using pre-defined pydantic schema in 'SolverConfig', error(s) will be raised - # if any checking failed at this step if importlib.util.find_spec("pydantic") is not None: from pydantic import ValidationError else: @@ -76,8 +75,6 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: # error(s) will be printed and exit program if any checking failed at this step try: _model_pydantic = config_module.SolverConfig(**dict(config)) - # complete missing items with default values pre-defined in pydantic schema in - # 'SolverConfig' full_cfg = DictConfig(_model_pydantic.model_dump()) except ValidationError as e: print(e) @@ -100,7 +97,7 @@ def on_job_start(self, config: DictConfig, **kwargs: Any) -> None: # enable prim if specified if "prim" in full_cfg and bool(full_cfg.prim): - # Mostly for dy2st running, will be removed in the future + # Mostly for compiler running with dy2st. from paddle.framework import core core.set_prim_eager_enabled(True) diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py index af28f2e207..0352d2f7ce 100644 --- a/ppsci/utils/config.py +++ b/ppsci/utils/config.py @@ -14,30 +14,77 @@ from __future__ import annotations -import argparse -import copy import importlib.util -import os from typing import Mapping from typing import Optional from typing import Tuple -import yaml -from paddle import static from typing_extensions import Literal -from ppsci.utils import logger -from ppsci.utils import misc - -__all__ = ["get_config", "replace_shape_with_inputspec_", "AttrDict"] +__all__ = [] if importlib.util.find_spec("pydantic") is not None: + from hydra.core.config_store import ConfigStore + from omegaconf import OmegaConf from pydantic import BaseModel from pydantic import field_validator + from pydantic import model_validator from pydantic_core.core_schema import ValidationInfo __all__.append("SolverConfig") + class EMAConfig(BaseModel): + use_ema: bool = False + decay: float = 0.9 + avg_freq: int = 1 + + @field_validator("decay") + def decay_check(cls, v): + if v <= 0 or v >= 1: + raise ValueError( + f"'decay' should be in (0, 1) when is type of float, but got {v}" + ) + return v + + @field_validator("avg_freq") + def avg_freq_check(cls, v): + if v <= 0: + raise ValueError( + "'avg_freq' should be a positive integer when is type of int, " + f"but got {v}" + ) + return v + + class SWAConfig(BaseModel): + use_swa: bool = False + avg_freq: int = 1 + avg_range: Optional[Tuple[int, int]] = None + + @field_validator("avg_range") + def avg_range_check(cls, v, info: ValidationInfo): + if isinstance(v, tuple) and v[0] > v[1]: + raise ValueError(f"'avg_range' should be a valid range, but got {v}.") + if isinstance(v, tuple) and v[0] < 0: + raise ValueError( + "The start epoch of 'avg_range' should be a non-negative integer" + f" , but got {v[0]}." + ) + if isinstance(v, tuple) and v[1] > info.data["epochs"]: + raise ValueError( + "The end epoch of 'avg_range' should not be larger than " + f"'epochs'({info.data['epochs']}), but got {v[1]}." + ) + return v + + @field_validator("avg_freq") + def avg_freq_check(cls, v): + if v <= 0: + raise ValueError( + "'avg_freq' should be a positive integer when is type of int, " + f"but got {v}" + ) + return v + class TrainConfig(BaseModel): """ Schema of training config for pydantic validation.
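Because `EMAConfig` and `SWAConfig` are now standalone pydantic models gated by the new `use_ema`/`use_swa` switches (both off by default), their behaviour can be checked in isolation. A small sketch, assuming pydantic v2 is installed and importing the schema classes directly from `ppsci.utils.config` (they live at module level there but are not exported via `__all__`):

```python
from pydantic import ValidationError

from ppsci.utils.config import EMAConfig, SWAConfig

# Both averaging strategies stay disabled unless explicitly switched on.
print(EMAConfig().model_dump())  # {'use_ema': False, 'decay': 0.9, 'avg_freq': 1}
print(SWAConfig().model_dump())  # {'use_swa': False, 'avg_freq': 1, 'avg_range': None}

# Field validators still reject out-of-range values at construction time.
try:
    EMAConfig(use_ema=True, decay=1.5)
except ValidationError as err:
    print(err)  # 'decay' should be in (0, 1) ...
```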
@@ -55,58 +102,6 @@ class TrainConfig(BaseModel): ema: Optional[EMAConfig] = None swa: Optional[SWAConfig] = None - class EMAConfig(BaseModel): - decay: float = 0.9 - avg_freq: int = 1 - - @field_validator("decay") - def decay_check(cls, v): - if v <= 0 or v >= 1: - raise ValueError( - f"'decay' should be in (0, 1) when is type of float, but got {v}" - ) - return v - - @field_validator("avg_freq") - def avg_freq_check(cls, v): - if v <= 0: - raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " - f"but got {v}" - ) - return v - - class SWAConfig(BaseModel): - avg_freq: int = 1 - avg_range: Optional[Tuple[int, int]] = None - - @field_validator("avg_range") - def avg_range_check(cls, v, info: ValidationInfo): - if v[0] > v[1]: - raise ValueError( - f"'avg_range' should be a valid range, but got {v}." - ) - if v[0] < 0: - raise ValueError( - "The start epoch of 'avg_range' should be a non-negtive integer" - f" , but got {v[0]}." - ) - if v[1] > info.data["epochs"]: - raise ValueError( - "The end epoch of 'avg_range' should not be lager than " - f"'epochs'({info.data['epochs']}), but got {v[1]}." - ) - return v - - @field_validator("avg_freq") - def avg_freq_check(cls, v): - if v <= 0: - raise ValueError( - "'avg_freq' should be a positive integer when is type of int, " - f"but got {v}" - ) - return v - # Fine-grained validator(s) below @field_validator("epochs") def epochs_check(cls, v): @@ -164,21 +159,14 @@ def eval_freq_check(cls, v, info: ValidationInfo): ) return v - @field_validator("ema") - def ema_check(cls, v, info: ValidationInfo): - if "swa" in info.data and info.data["swa"] is not None: - raise ValueError( - "The config of 'swa' should not be used when 'ema' is specifed." - ) - return v - - @field_validator("swa") - def swa_check(cls, v, info: ValidationInfo): - if "ema" in info.data and info.data["ema"] is not None: + @model_validator(mode="after") + def ema_swa_checker(self): + if (self.ema and self.swa) and (self.ema.use_ema and self.swa.use_swa): raise ValueError( - "The config of 'ema' should not be used when 'swa' is specifed." + "Cannot enable both EMA and SWA at the same time, " + "please disable at least one of them." 
) - return v + return self class EvalConfig(BaseModel): """ @@ -195,7 +183,7 @@ class InferConfig(BaseModel): """ pretrained_model_path: Optional[str] = None - export_path: str + export_path: str = "./inference" pdmodel_path: Optional[str] = None pdiparams_path: Optional[str] = None onnx_path: Optional[str] = None @@ -284,8 +272,9 @@ class SolverConfig(BaseModel): log_freq: int = 20 seed: int = 42 use_vdl: bool = False - use_wandb: bool = False + use_tbd: bool = False wandb_config: Optional[Mapping] = None + use_wandb: bool = False device: Literal["cpu", "gpu", "xpu"] = "gpu" use_amp: bool = False amp_level: Literal["O0", "O1", "O2", "OD"] = "O1" @@ -320,195 +309,99 @@ def seed_check(cls, v): @field_validator("use_wandb") def use_wandb_check(cls, v, info: ValidationInfo): - if not isinstance(info.data["wandb_config"], dict): + if v and not isinstance(info.data["wandb_config"], dict): raise ValueError( "'wandb_config' should be a dict when 'use_wandb' is True, " - f"but got {misc.typename(info.data['wandb_config'])}" + f"but got {info.data['wandb_config'].__class__.__name__}" ) return v - -class AttrDict(dict): - def __getattr__(self, key): - return self[key] - - def __setattr__(self, key, value): - if key in self.__dict__: - self.__dict__[key] = value - else: - self[key] = value - - def __deepcopy__(self, content): - return AttrDict(copy.deepcopy(dict(self))) - - -def create_attr_dict(yaml_config): - from ast import literal_eval - - for key, value in yaml_config.items(): - if isinstance(value, dict): - yaml_config[key] = value = AttrDict(value) - if isinstance(value, str): - try: - value = literal_eval(value) - except BaseException: - pass - if isinstance(value, AttrDict): - create_attr_dict(yaml_config[key]) - else: - yaml_config[key] = value - - -def parse_config(cfg_file): - """Load a config file into AttrDict""" - with open(cfg_file, "r") as fopen: - yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader)) - create_attr_dict(yaml_config) - return yaml_config - - -def print_dict(d, delimiter=0): - """ - Recursively visualize a dict and - indenting according by the relationship of keys. - """ - placeholder = "-" * 60 - for k, v in d.items(): - if isinstance(v, dict): - logger.info(f"{delimiter * ' '}{k} : ") - print_dict(v, delimiter + 4) - elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict): - logger.info(f"{delimiter * ' '}{k} : ") - for value in v: - print_dict(value, delimiter + 2) - else: - logger.info(f"{delimiter * ' '}{k} : {v}") - - if k[0].isupper() and delimiter == 0: - logger.info(placeholder) - - -def print_config(config): - """ - Visualize configs - Arguments: - config: configs - """ - logger.advertise() - print_dict(config) - - -def override(dl, ks, v): + # Register 'XXXConfig' as default node, so as to be used as default config in *.yaml """ - Recursively replace dict of list - Args: - dl(dict or list): dict or list to be replaced - ks(list): list of keys - v(str): value to be replaced + #### xxx.yaml #### + defaults: + - ppsci_default <-- 'ppsci_default' used here + - TRAIN: train_default <-- 'train_default' used here + - TRAIN/ema: ema_default <-- 'ema_default' used here + - TRAIN/swa: swa_default <-- 'swa_default' used here + - EVAL: eval_default <-- 'eval_default' used here + - INFER: infer_default <-- 'infer_default' used here + - _self_ + mode: train + seed: 42 + ... + ... 
+ ################## """ - def str2num(v): - try: - return eval(v) - except Exception: - return v - - if not isinstance(dl, (list, dict)): - raise ValueError(f"{dl} should be a list or a dict") - if len(ks) <= 0: - raise ValueError("length of keys should be larger than 0") - - if isinstance(dl, list): - k = str2num(ks[0]) - if len(ks) == 1: - if k >= len(dl): - raise ValueError(f"index({k}) out of range({dl})") - dl[k] = str2num(v) - else: - override(dl[k], ks[1:], v) - else: - if len(ks) == 1: - # assert ks[0] in dl, (f"{ks[0]} is not exist in {dl}") - if ks[0] not in dl: - print(f"A new field ({ks[0]}) detected!") - dl[ks[0]] = str2num(v) - else: - if ks[0] not in dl.keys(): - dl[ks[0]] = {} - print(f"A new Series field ({ks[0]}) detected!") - override(dl[ks[0]], ks[1:], v) - - -def override_config(config, options=None): - """ - Recursively override the config - Args: - config(dict): dict to be replaced - options(list): list of pairs(key0.key1.idx.key2=value) - such as: [ - "topk=2", - "VALID.transforms.1.ResizeImage.resize_short=300" - ] - Returns: - config(dict): replaced config - """ - if options is not None: - for opt in options: - assert isinstance(opt, str), f"option({opt}) should be a str" - assert ( - "=" in opt - ), f"option({opt}) should contain a = to distinguish between key and value" - pair = opt.split("=") - assert len(pair) == 2, "there can be only a = in the option" - key, value = pair - keys = key.split(".") - override(config, keys, value) - return config - - -def get_config(fname, overrides=None, show=False): - """ - Read config from file - """ - if not os.path.exists(fname): - raise FileNotFoundError(f"config file({fname}) is not exist") - config = parse_config(fname) - override_config(config, overrides) - if show: - print_config(config) - return config - - -def parse_args(): - parser = argparse.ArgumentParser("paddlescience running script") - parser.add_argument("-e", "--epochs", type=int, help="training epochs") - parser.add_argument("-o", "--output_dir", type=str, help="output directory") - parser.add_argument( - "--to_static", - action="store_true", - help="whether enable to_static for forward computation", + cs = ConfigStore.instance() + + global_default_cfg = SolverConfig().model_dump() + omegaconf_dict_config = OmegaConf.create(global_default_cfg) + cs.store(name="ppsci_default", node=omegaconf_dict_config) + + train_default_cfg = TrainConfig().model_dump() + train_omegaconf_dict_config = OmegaConf.create(train_default_cfg) + cs.store(group="TRAIN", name="train_default", node=train_omegaconf_dict_config) + + ema_default_cfg = EMAConfig().model_dump() + ema_omegaconf_dict_config = OmegaConf.create(ema_default_cfg) + cs.store(group="TRAIN/ema", name="ema_default", node=ema_omegaconf_dict_config) + + swa_default_cfg = SWAConfig().model_dump() + swa_omegaconf_dict_config = OmegaConf.create(swa_default_cfg) + cs.store(group="TRAIN/swa", name="swa_default", node=swa_omegaconf_dict_config) + + eval_default_cfg = EvalConfig().model_dump() + eval_omegaconf_dict_config = OmegaConf.create(eval_default_cfg) + cs.store(group="EVAL", name="eval_default", node=eval_omegaconf_dict_config) + + infer_default_cfg = InferConfig().model_dump() + infer_omegaconf_dict_config = OmegaConf.create(infer_default_cfg) + cs.store(group="INFER", name="infer_default", node=infer_omegaconf_dict_config) + + exclude_keys_default = [ + "mode", + "output_dir", + "log_freq", + "seed", + "use_vdl", + "use_tbd", + "wandb_config", + "use_wandb", + "device", + "use_amp", + "amp_level", + "to_static", + 
"prim", + "log_level", + "TRAIN.save_freq", + "TRAIN.eval_during_train", + "TRAIN.start_eval_epoch", + "TRAIN.eval_freq", + "TRAIN.checkpoint_path", + "TRAIN.pretrained_model_path", + "EVAL.pretrained_model_path", + "EVAL.eval_with_no_grad", + "EVAL.compute_metric_by_batch", + "INFER.pretrained_model_path", + "INFER.export_path", + "INFER.pdmodel_path", + "INFER.pdiparams_path", + "INFER.onnx_path", + "INFER.device", + "INFER.engine", + "INFER.precision", + "INFER.ir_optim", + "INFER.min_subgraph_size", + "INFER.gpu_mem", + "INFER.gpu_id", + "INFER.max_batch_size", + "INFER.num_cpu_threads", + "INFER.batch_size", + ] + cs.store( + group="hydra/job/config/override_dirname/exclude_keys", + name="exclude_keys_default", + node=exclude_keys_default, ) - - args = parser.parse_args() - return args - - -def _is_num_seq(seq): - # whether seq is all int number(it is a shape) - return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq) - - -def replace_shape_with_inputspec_(node: AttrDict): - if _is_num_seq(node): - return True - - if isinstance(node, dict): - for key in node: - if replace_shape_with_inputspec_(node[key]): - node[key] = static.InputSpec(node[key]) - elif isinstance(node, list): - for i in range(len(node)): - if replace_shape_with_inputspec_(node[i]): - node[i] = static.InputSpec(node[i]) - - return False diff --git a/ppsci/validate/__init__.py b/ppsci/validate/__init__.py index 9e05b13665..3bc1c9ae4d 100644 --- a/ppsci/validate/__init__.py +++ b/ppsci/validate/__init__.py @@ -33,7 +33,7 @@ def build_validator(cfg, equation_dict, geom_dict): """Build validator(s). Args: - cfg (List[AttrDict]): Validator(s) config list. + cfg (List[DictConfig]): Validator(s) config list. geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. equation_dict (Dct[str, Equation]): Equation(s) in dict. diff --git a/ppsci/visualize/__init__.py b/ppsci/visualize/__init__.py index 7beea234c5..73cd0e0953 100644 --- a/ppsci/visualize/__init__.py +++ b/ppsci/visualize/__init__.py @@ -55,7 +55,7 @@ def build_visualizer(cfg): """Build visualizer(s). Args: - cfg (List[AttrDict]): Visualizer(s) config list. + cfg (List[DictConfig]): Visualizer(s) config list. geom_dict (Dct[str, Geometry]): Geometry(ies) in dict. equation_dict (Dct[str, Equation]): Equation(s) in dict. diff --git a/test/utils/test_config.py b/test/utils/test_config.py index 5f650685c8..844d1f449f 100644 --- a/test/utils/test_config.py +++ b/test/utils/test_config.py @@ -1,21 +1,11 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+import os import hydra import paddle import pytest -from omegaconf import DictConfig +import yaml + +from ppsci.utils.callbacks import InitCallback paddle.seed(1024) @@ -28,42 +18,33 @@ (10, "eval", -1), ], ) -def test_invalid_epochs( - epochs, - mode, - seed, -): - @hydra.main(version_base=None, config_path="./", config_name="test_config.yaml") - def main(cfg: DictConfig): - pass - - # sys.exit will be called when validation error in pydantic, so there we use - # SystemExit instead of other type of errors. - with pytest.raises(SystemExit): - cfg_dict = dict( - { - "TRAIN": { - "epochs": epochs, - }, - "mode": mode, - "seed": seed, - "hydra": { - "callbacks": { - "init_callback": { - "_target_": "ppsci.utils.callbacks.InitCallback" - } - } - }, +def test_invalid_epochs(tmpdir, epochs, mode, seed): + cfg_dict = { + "hydra": { + "callbacks": { + "init_callback": {"_target_": "ppsci.utils.callbacks.InitCallback"} } - ) - # print(cfg_dict) - import yaml + }, + "mode": mode, + "seed": seed, + "TRAIN": { + "epochs": epochs, + }, + } + + dir_ = os.path.dirname(__file__) + config_abs_path = os.path.join(dir_, "test_config.yaml") + with open(config_abs_path, "w") as f: + f.write(yaml.dump(cfg_dict)) - with open("test_config.yaml", "w") as f: - yaml.dump(dict(cfg_dict), f) + with hydra.initialize(config_path="./", version_base=None): + cfg = hydra.compose(config_name="test_config.yaml") - main() + with pytest.raises(SystemExit) as exec_info: + InitCallback().on_job_start(config=cfg) + assert exec_info.value.code == 2 +# This part is usually unnecessary unless you want to run the tests directly from this script. if __name__ == "__main__": pytest.main() diff --git a/test/utils/test_writer.py b/test/utils/test_writer.py index 6e960bee28..cce3f69ab8 100644 --- a/test/utils/test_writer.py +++ b/test/utils/test_writer.py @@ -21,13 +21,11 @@ def test_save_csv_file(): keys = ["x1", "y1", "z1"] - alias_dict = ( - { - "x": "x1", - "y": "y1", - "z": "z1", - }, - ) + alias_dict = { + "x": "x1", + "y": "y1", + "z": "z1", + } data_dict = { keys[0]: np.random.randint(0, 255, (10, 1)), keys[1]: np.random.rand(10, 1),