# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
  job:
    chdir: true
    name: patched_diffusion_opt
  run:
    dir: ./output/${hydra:job.name}
  searchpath:
    - pkg://conf/base # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

  - dataset: hrrr_corrdiff_synthetic
    # The dataset type for training.
    # Accepted values:
    #   `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
    #   `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
    #   `cwb`: full CWB dataset for Taiwan.
    #   `custom`: user-defined dataset. Parameters need to be specified below.

  - model: patched_diffusion
    # The model type.
    # Accepted values:
    #   `regression`: a regression UNet for deterministic predictions
    #   `lt_aware_ce_regression`: similar to `regression` but with lead time
    #     conditioning
    #   `diffusion`: a diffusion UNet for residual predictions
    #   `patched_diffusion`: a more memory-efficient diffusion model
    #   `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
    #     with lead time conditioning

  - model_size: normal
    # The model size configuration.
    # Accepted values:
    #   `normal`: normal model size
    #   `mini`: smaller model size for fast experiments

  - training: ${model}
    # The base training parameters. Determined by the model type.


# Dataset parameters. Used for `custom` dataset type.
# Modify or add below parameters that should be passed as argument to the
# user-defined dataset class.
dataset:
  data_path: ./data
  # Path to .nc data file
  stats_path: ./data/stats.json
  # Path to json stats file

# Training parameters
training:
  hp:
    training_duration: 200000000
    # Training duration based on the number of processed samples
    total_batch_size: 512
    # Total batch size
    batch_size_per_gpu: 4

    patch_shape_x: 448
    patch_shape_y: 448
    # Patch size. Patch training is used if these dimensions differ from
    # img_shape_x and img_shape_y.
    patch_num: 16
    # Number of patches from a single sample. Total number of patches is
    # patch_num * batch_size_global.
    max_patch_per_gpu: 9
    # Maximum number of patches a GPU can hold

    lr: 0.0002
    # Learning rate
    grad_clip_threshold: 1e6
    lr_decay: 0.7
    lr_rampup: 1000000

  # Performance
  perf:
    fp_optimizations: amp-bf16
    # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
    # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
    dataloader_workers: 4
    # DataLoader worker processes
    songunet_checkpoint_level: 0 # 0 means no checkpointing
    # Gradient checkpointing level, value is number of layers to checkpoint
    # optimization_mode: True
    use_apex_gn: True
    torch_compile: True
    profile_mode: False

  io:
    regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus
    # Path to load the regression checkpoint

    print_progress_freq: 1000
    # How often to print progress
    save_checkpoint_freq: 500000
    # How often to save the checkpoints, measured in number of processed samples
    validation_freq: 5000
    # How often to record the validation loss, measured in number of processed samples
    validation_steps: 10
    # How many loss evaluations are used to compute the validation loss per checkpoint

# Parameters for wandb logging
wandb:
  mode: offline
  # Configure whether to use wandb: "offline", "online", "disabled"
  results_dir: "./wandb"
  # Directory to store wandb results
  watch_model: false
  # If true, wandb will track model parameters and gradients