# config.yaml
# Hydra defaults list: this file itself (`_self_`), the selected `dset`, `svd`
# and `variant` config groups, and the colorlog overrides for Hydra's logging.
defaults:
  - _self_
  - dset: musdb44
  - svd: default
  - variant: default
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog
# NOTE(review): `dummy` carries no value; its purpose is not evident from this
# file.
dummy: null

# Dataset settings. All keys below belong to `dset` (the Dora exclude list
# refers to `dset.backend`).
dset:
  musdb: /checkpoint/defossez/datasets/musdbhq
  musdb_samplerate: 44100
  use_musdb: true  # set to false to not use musdb as training data.
  wav: null  # path to custom wav dataset
  wav2: null  # second custom wav dataset
  segment: 11
  shift: 1
  train_valid: false
  full_cv: true
  samplerate: 44100
  channels: 2
  normalize: true
  metadata: ./metadata
  sources: ['drums', 'bass', 'other', 'vocals']
  valid_samples: null  # valid dataset size
  backend: null  # if provided select torchaudio backend.
# Test / evaluation settings.
test:
  save: false
  best: true
  workers: 2
  every: 20
  split: true
  shifts: 1
  overlap: 0.25
  sdr: true
  # metric used for best model selection on the valid set, can also be nsdr
  metric: 'loss'
  nonhq: null  # path to non hq MusDB for evaluation
epochs: 360
batch_size: 64
# limit the number of batches per epoch, useful for debugging
# or if your dataset is gigantic.
max_batches: null

# Optimizer and training-loss settings.
optim:
  # Written with a decimal point so that plain YAML 1.1 loaders also read it
  # as a float rather than a string.
  lr: 3.0e-4
  momentum: 0.9
  beta2: 0.999
  loss: l1  # l1 or mse
  optim: adam
  weight_decay: 0
  clip_grad: 0

seed: 42
debug: false
valid_apply: true
flag: null
save_every: null
# weights over each source for the training/valid loss.
weights: [1.0, 1.0, 1.0, 1.0]
# Augmentation settings. `repitch`, `remix` and `scale` are sub-sections, each
# with its own `proba`.
augment:
  shift_same: false
  repitch:
    proba: 0.2
    max_tempo: 12
  remix:
    proba: 1
    group_size: 4
  scale:
    proba: 1
    min: 0.25
    max: 1.25
  flip: true
# continue from other XP, give the XP Dora signature.
continue_from: null
# signature of a pretrained XP, this cannot be a bag of models.
continue_pretrained: null
# repo for pretrained model (default is official AWS)
pretrained_repo: null
continue_best: true
continue_opt: false

# Miscellaneous settings (matched by the `misc.*` pattern in `dora.exclude`).
misc:
  num_workers: 10
  num_prints: 4
  show: false
  verbose: false
# List of decays for EMA at batch or epoch level, e.g. 0.999.
# Batch level EMA are kept on GPU for speed.
ema:
  epoch: []
  batch: []

use_train_segment: true  # to remove
# override the segment parameter for the model, usually 4 times the training
# segment.
model_segment: null
# see demucs/train.py for the possibilities, and config for each model
# hereafter.
model: demucs
demucs:  # see demucs/demucs.py for a detailed description
  # Channels
  channels: 64
  growth: 2
  # Main structure
  depth: 6
  rewrite: true
  lstm_layers: 0
  # Convolutions
  kernel_size: 8
  stride: 4
  context: 1
  # Activations
  gelu: true
  glu: true
  # Normalization
  norm_groups: 4
  norm_starts: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_mode: 1  # 1 = branch in encoder, 2 = in decoder, 3 = in both.
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  # Decimal point keeps this a float under plain YAML 1.1 loaders.
  dconv_init: 1.0e-4
  # Pre/post treatment
  resample: true
  normalize: false
  # Weight init
  rescale: 0.1
hdemucs:  # see demucs/hdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time: null
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 6
  rewrite: true
  hybrid: true
  hybrid_old: false
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  # Decimal point keeps this a float under plain YAML 1.1 loaders.
  dconv_init: 1.0e-3
  # Weight init
  rescale: 0.1
# Torchaudio implementation of HDemucs
torch_hdemucs:
  # Channels
  channels: 48
  growth: 2
  # STFT
  nfft: 4096
  # Main structure
  depth: 6
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_depth: 2
  dconv_comp: 4
  dconv_attn: 4
  dconv_lstm: 4
  # Decimal point keeps this a float under plain YAML 1.1 loaders.
  dconv_init: 1.0e-3
htdemucs:  # see demucs/htdemucs.py for a detailed description
  # Channels
  channels: 48
  channels_time: null
  growth: 2
  # STFT
  nfft: 4096
  wiener_iters: 0
  end_iters: 0
  wiener_residual: false
  cac: true
  # Main structure
  depth: 4
  rewrite: true
  # Frequency Branch
  multi_freqs: []
  multi_freqs_depth: 3
  freq_emb: 0.2
  emb_scale: 10
  emb_smooth: true
  # Convolutions
  kernel_size: 8
  stride: 4
  time_stride: 2
  context: 1
  context_enc: 0
  # normalization
  norm_starts: 4
  norm_groups: 4
  # DConv residual branch
  dconv_mode: 1
  dconv_depth: 2
  dconv_comp: 8
  # Decimal point keeps this a float under plain YAML 1.1 loaders.
  dconv_init: 1.0e-3
  # Before the Transformer
  bottom_channels: 0
  # CrossTransformer
  # ------ Common to all
  # Regular parameters
  t_layers: 5
  t_hidden_scale: 4.0
  t_heads: 8
  t_dropout: 0.0
  t_layer_scale: true
  t_gelu: true
  # ------------- Positional Embedding
  t_emb: sin
  t_max_positions: 10000  # for the scaled embedding
  t_max_period: 10000.0
  t_weight_pos_embed: 1.0
  t_cape_mean_normalize: true
  t_cape_augment: true
  t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
  t_sin_random_shift: 0
  # ------------- norm before a transformer encoder
  t_norm_in: true
  t_norm_in_group: false
  # ------------- norm inside the encoder
  t_group_norm: false
  t_norm_first: true
  t_norm_out: true
  # ------------- optim
  t_weight_decay: 0.0
  t_lr: null
  # ------------- sparsity
  t_sparse_self_attn: false
  t_sparse_cross_attn: false
  t_mask_type: diag
  t_mask_random_seed: 42
  t_sparse_attn_window: 400
  t_global_window: 100
  t_sparsity: 0.95
  t_auto_sparsity: false
  # Cross Encoder First (False)
  t_cross_first: false
  # Weight init
  rescale: 0.1
svd:  # see svd.py for documentation
  penalty: 0
  min_size: 0.1
  dim: 1
  niters: 2
  powm: false
  proba: 1
  conv_only: false
  convtr: false
  bs: 1
quant:  # quantization hyper params
  diffq: null  # diffq penalty, typically 1e-4 or 3e-4
  # use QAT with a fixed number of bits (not as good as diffq)
  qat: null
  min_size: 0.2
  group_size: 8
# Dora settings.
# NOTE(review): `exclude` presumably lists config keys left out of the XP
# signature; confirm against the Dora documentation.
dora:
  dir: outputs
  exclude: ['misc.*', 'slurm.*', 'test.reval', 'flag', 'dset.backend']
# Slurm job settings.
# NOTE(review): the unit of `time` is not stated in this file; confirm.
slurm:
  time: 4320
  constraint: volta32gb
  setup:
    - 'module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6'
# Hydra config: sets the date format of the `colorlog` formatter used for job
# logging.
hydra:
  job_logging:
    formatters:
      colorlog:
        datefmt: '%m-%d %H:%M:%S'