Commit 1a6288d

LostnEkko, Alicia Sui, jialusui1102, and CharlelieLrt authored
Various CorrDiff optimizations for a drastic increase in training efficiency (#809)
* support multi-GPU training for CorrDiff optimization
* enable mixed precision for validation
* clean codebase for optimization
* add amp_mode-aware model architecture
* add None checking for params
* revise datatype casting schema
* add test cases for CorrDiff optimizations
* revise from_checkpoint, update tests and CHANGELOG
* lint and format code properly
* add multi-GPU optimization
* rebase changes and update tests and configs
* merge ResidualLoss and refactor layer and UNet init based on PR review
* update layers.py with robust apex import
* address incompatibility between dynamo and patching, retaining the same optimization performance with torch.compile
* update tests
* update changelog
* initialize global_index directly on device
* formatting

Signed-off-by: Neal Pan <[email protected]>
Signed-off-by: jialusui1102 <[email protected]>
Co-authored-by: Alicia Sui <[email protected]>
Co-authored-by: jialusui1102 <[email protected]>
Co-authored-by: Charlelie Laurent <[email protected]>
1 parent 37b1da0 commit 1a6288d

16 files changed: +1859 −443 lines

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
@@ -23,6 +23,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - ERA5 download example updated to use current file format convention and
   restricts global statistics computation to the training set
 - Support for training custom StormCast models and various other improvements for StormCast
+- Updated CorrDiff training code to support multiple patch iterations to amortize
+  regression cost and usage of `torch.compile`
+- Refactored `physicsnemo/models/diffusion/layers.py` to optimize the data type
+  casting workflow, avoiding unnecessary casting under autocast mode
+- Refactored `Conv2d` to enable fusion of conv2d with bias addition
+- Refactored `GroupNorm`, `UNetBlock`, `SongUNet`, and `SongUNetPosEmbd` to support
+  Apex GroupNorm, fusion of activation with GroupNorm, and the AMP workflow
+- Updated `SongUNetPosEmbd` to avoid an unnecessary HtoD Memcpy of `pos_embd`
+- Updated `from_checkpoint` to accommodate conversion between Apex-optimized and
+  non-optimized checkpoints
+- Refactored the CorrDiff NVTX annotation workflow to be configurable
+- Refactored `ResidualLoss` to support patch-accumulating training for
+  amortizing regression costs (see the sketch after this diff)

 ### Deprecated
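A minimal sketch of the patch-accumulating idea behind the `ResidualLoss` refactor above: the expensive regression pass runs once per image batch, and its output is reused while the diffusion loss accumulates over several patch iterations. `random_patches`, `regression_net`, and `diffusion_loss` are hypothetical stand-ins for illustration, not the PhysicsNeMo API.

```python
import torch

def random_patches(x: torch.Tensor, n: int, patch: int) -> torch.Tensor:
    """Crop n random square patches of size `patch` from each image in x."""
    b, c, h, w = x.shape
    crops = []
    for _ in range(n):
        i = torch.randint(0, h - patch + 1, (1,)).item()
        j = torch.randint(0, w - patch + 1, (1,)).item()
        crops.append(x[:, :, i:i + patch, j:j + patch])
    return torch.cat(crops, dim=0)  # (n * b, c, patch, patch)

def patch_accumulated_loss(regression_net, diffusion_loss, img_lr, img_hr,
                           patch_num=16, max_patch_per_gpu=9, patch=448):
    # Run the costly regression UNet once per image batch.
    with torch.no_grad():
        mean_pred = regression_net(img_lr)
    residual = img_hr - mean_pred  # the diffusion model learns this residual

    # Accumulate the diffusion loss over patch iterations that fit on the GPU,
    # reusing the single regression pass above for every iteration.
    total, done = residual.new_zeros(()), 0
    while done < patch_num:
        n = min(max_patch_per_gpu, patch_num - done)
        total = total + diffusion_loss(random_patches(residual, n, patch))
        done += n
    return total / patch_num
```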

examples/generative/corrdiff/conf/base/model/patched_diffusion.yaml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-name: diffusion
+name: patched_diffusion
 # Model type.
 hr_mean_conditioning: True
 # Recommended to use high-res conditioning for diffusion.

examples/generative/corrdiff/conf/base/model_size/normal.yaml

Lines changed: 1 addition & 1 deletion
@@ -23,4 +23,4 @@ model_args:
   # Per-resolution multipliers for the number of channels.
   channel_mult: [1, 2, 2, 2, 2]
   # Resolutions at which self-attention layers are applied.
-  attention_levels: [28]
+  attn_resolutions: [28]
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+hydra:
+  job:
+    chdir: true
+    name: patched_diffusion_opt
+  run:
+    dir: ./output/${hydra:job.name}
+  searchpath:
+    - pkg://conf/base # Do not modify
+
+# Base parameters for dataset, model, training, and validation
+defaults:
+
+  - dataset: hrrr_corrdiff_synthetic
+  # The dataset type for training.
+  # Accepted values:
+  #   `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
+  #   `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
+  #   `cwb`: full CWB dataset for Taiwan.
+  #   `custom`: user-defined dataset. Parameters need to be specified below.
+
+  - model: patched_diffusion
+  # The model type.
+  # Accepted values:
+  #   `regression`: a regression UNet for deterministic predictions
+  #   `lt_aware_ce_regression`: similar to `regression` but with lead time
+  #     conditioning
+  #   `diffusion`: a diffusion UNet for residual predictions
+  #   `patched_diffusion`: a more memory-efficient diffusion model
+  #   `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
+  #     with lead time conditioning
+
+  - model_size: normal
+  # The model size configuration.
+  # Accepted values:
+  #   `normal`: normal model size
+  #   `mini`: smaller model size for fast experiments
+
+  - training: ${model}
+  # The base training parameters. Determined by the model type.
+
+
+# Dataset parameters. Used for `custom` dataset type.
+# Modify or add below parameters that should be passed as argument to the
+# user-defined dataset class.
+dataset:
+  data_path: ./data
+  # Path to .nc data file
+  stats_path: ./data/stats.json
+  # Path to json stats file
+
+# Training parameters
+training:
+  hp:
+    training_duration: 200000000
+    # Training duration based on the number of processed samples
+    total_batch_size: 512
+    # Total batch size
+    batch_size_per_gpu: 4
+
+    patch_shape_x: 448
+    patch_shape_y: 448
+    # Patch size. Patch training is used if these dimensions differ from
+    # img_shape_x and img_shape_y.
+    patch_num: 16
+    # Number of patches from a single sample. Total number of patches is
+    # patch_num * batch_size_global.
+    max_patch_per_gpu: 9
+    # Maximum number of patches a GPU can hold
+
+    lr: 0.0002
+    # Learning rate
+    grad_clip_threshold: 1e6
+    lr_decay: 0.7
+    lr_rampup: 1000000
+
+  # Performance
+  perf:
+    fp_optimizations: amp-bf16
+    # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
+    # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16,bfloat16}
+    dataloader_workers: 4
+    # DataLoader worker processes
+    songunet_checkpoint_level: 0 # 0 means no checkpointing
+    # Gradient checkpointing level, value is number of layers to checkpoint
+    # optimization_mode: True
+    use_apex_gn: True
+    torch_compile: True
+    profile_mode: False
+
+  io:
+    regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus
+    # Path to load the regression checkpoint
+
+    print_progress_freq: 1000
+    # How often to print progress
+    save_checkpoint_freq: 500000
+    # How often to save the checkpoints, measured in number of processed samples
+    validation_freq: 5000
+    # How often to record the validation loss, measured in number of processed samples
+    validation_steps: 10
+    # How many loss evaluations are used to compute the validation loss per checkpoint
+
+# Parameters for wandb logging
+wandb:
+  mode: offline
+  # Configure whether to use wandb: "offline", "online", "disabled"
+  results_dir: "./wandb"
+  # Directory to store wandb results
+  watch_model: false
+  # If true, wandb will track model parameters and gradients
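A hedged sketch of how the `fp_optimizations` modes in the config above can map onto PyTorch autocast in a training step; the helper name `autocast_for` and the config access pattern are illustrative assumptions, not the CorrDiff training code itself.

```python
import contextlib
import torch

def autocast_for(fp_optimizations: str):
    """Map a mode string onto an autocast context: "fp32" and "fp16" run
    without autocast, while "amp-fp16"/"amp-bf16" enable AMP in that dtype."""
    if fp_optimizations == "amp-fp16":
        return torch.autocast("cuda", dtype=torch.float16)
    if fp_optimizations == "amp-bf16":
        return torch.autocast("cuda", dtype=torch.bfloat16)
    return contextlib.nullcontext()

# Hypothetical usage in a training step:
# with autocast_for(cfg.training.perf.fp_optimizations):
#     loss = loss_fn(model(batch))
```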

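Similarly, a worked reading of how `total_batch_size` and `batch_size_per_gpu` can relate under data parallelism; `num_gpus` is a hypothetical world size, and gradient accumulation is a common way (assumed here, not confirmed by this commit) to bridge the gap.

```python
import math

total_batch_size = 512   # from the config above
batch_size_per_gpu = 4   # from the config above
num_gpus = 8             # hypothetical world size

# Accumulation rounds needed per optimizer step under these assumptions.
grad_accum_rounds = math.ceil(total_batch_size / (num_gpus * batch_size_per_gpu))
print(grad_accum_rounds)  # -> 16
```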