-
Notifications
You must be signed in to change notification settings - Fork 10
/
NoBiasDecay.py
24 lines (21 loc) · 1.19 KB
/
NoBiasDecay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn
def noBiasDecay(model, lr, weight_decay, *, bias_lr_factor=2, decay_types=None):
    """Split ``model``'s parameters into optimizer param groups for "no bias decay".

    No bias decay: only apply weight decay to the weights in convolution and
    fully-connected layers. Biases (trained with ``bias_lr_factor * lr``) and
    every other parameter (e.g. normalization-layer weights, embeddings) are
    trained without weight decay.
    In paper [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)
    Ref: https://github.com/weiaicunzai/Bag_of_Tricks_for_Image_Classification_with_Convolutional_Neural_Networks/blob/master/utils.py

    Args:
        model: the ``nn.Module`` whose parameters are grouped.
        lr: base learning rate.
        weight_decay: weight decay applied to the conv/linear weight group.
        bias_lr_factor: multiplier applied to ``lr`` for the bias group
            (default 2, i.e. biases use ``2 * lr``).
        decay_types: tuple of module classes whose ``weight`` parameter gets
            weight decay. Defaults to ``nn.Conv{1,2,3}d``,
            ``nn.ConvTranspose{1,2,3}d`` and ``nn.Linear`` (subclasses included).

    Returns:
        A list of three param-group dicts, always in this order:
        ``[biases (lr * bias_lr_factor, no decay),
           other parameters (lr, no decay),
           conv/linear weights (lr, weight_decay)]``.
        Every unique parameter of ``model`` appears in exactly one group; a
        group may be empty.
    """
    if decay_types is None:
        decay_types = (
            nn.Conv1d, nn.Conv2d, nn.Conv3d,
            nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
            nn.Linear,
        )
    decay, bias_no_decay, weight_no_decay = [], [], []
    # A tied/shared parameter is reachable from several modules; putting the same
    # tensor in two groups makes torch.optim raise, so keep its first occurrence only.
    seen = set()
    for m in model.modules():
        is_decay_module = isinstance(m, decay_types)
        # recurse=False yields only the parameters owned directly by ``m``. Unlike
        # hasattr(m, 'weight'), it skips parameters registered as None (e.g.
        # BatchNorm(affine=False), bias=False) and ignores non-Parameter attributes
        # that happen to be named weight/bias, while also covering parameters with
        # other names (e.g. MultiheadAttention.in_proj_weight).
        for name, p in m.named_parameters(recurse=False):
            if id(p) in seen:
                continue
            seen.add(id(p))
            if name.endswith('bias'):
                bias_no_decay.append(p)
            elif is_decay_module and name == 'weight':
                decay.append(p)
            else:
                weight_no_decay.append(p)
    # Internal invariant: every unique parameter landed in exactly one group.
    assert len(list(model.parameters())) == len(decay) + len(bias_no_decay) + len(weight_no_decay)
    # bias using bias_lr_factor * lr (2 * lr by default)
    return [
        {'params': bias_no_decay, 'lr': bias_lr_factor * lr, 'weight_decay': 0.0},
        {'params': weight_no_decay, 'lr': lr, 'weight_decay': 0.0},
        {'params': decay, 'lr': lr, 'weight_decay': weight_decay},
    ]