-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsolver.py
145 lines (121 loc) · 5.65 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# import torch.utils.data.distributed
from misc.utils import to_cuda
from misc.utils import copy_file_to_screen
import warnings
import torch
import copy
import os
from misc.solver_utils import Solver_Utils
from munch import Munch
from data_loader import get_loader
# Silence every warning category for the whole process; this runs as soon as
# the module is imported.
warnings.simplefilter('ignore')
class Solver(Solver_Utils):
    """Build and hold the networks, optimizers and loaders for a run.

    Networks live in ``self.nets``: ``G`` (Generator), ``F`` (Noise2Style
    mapping network), ``S`` (StyleEncoder), ``D`` (Discriminator), plus an
    optional ``FAN``. ``self.nets_ema`` starts as a deep copy of ``self.nets``
    (presumably updated as an EMA elsewhere -- not shown here). In train mode
    one Adam optimizer per network is kept in ``self.optims``; it is not
    created in other modes. Helpers such as ``print_network``,
    ``load_pretrained_model`` and ``PRINT`` are presumably provided by
    ``Solver_Utils`` (not visible here).
    """

    def __init__(self, args, data_loader):
        """Derive dataset-dependent options, then build models and optimizers.

        Args:
            args: Options object. It is mutated in place: ``data_module``,
                ``domains``, ``n_domains``, ``c_dim`` and ``num_domains`` are
                written from ``data_loader.dataset``. ``args.dist`` supplies
                the ``rank``/``barrier`` helpers used below.
            data_loader: Loader whose ``dataset`` exposes ``selected_attrs``
                and ``parent_attrs``.
        """
        # Data loader
        self.data_loader = data_loader
        if args.mode == 'train':
            # The validation loader is only created when training.
            self.data_loader_val = get_loader(args,
                                              all_attr=args.ALL_ATTR,
                                              shuffling=True,
                                              mode='val')
        dataset = data_loader.dataset
        self.args = args
        self.args.data_module = dataset
        self.args.domains = dataset.selected_attrs
        self.args.n_domains = len(dataset.selected_attrs)
        self.args.c_dim = self.args.n_domains
        self.dist = args.dist
        is_master = self.dist.rank() == 0
        # Non-zero ranks stay quiet: ``verbose`` gates ``debug`` in
        # build_model.
        self.verbose = 1 if is_master else 0
        self.args.num_domains = len(dataset.parent_attrs)
        self.build_model()
        self.build_solver()
        if args.mode == 'train' and args.GPU[0] != '-1' and is_master:
            copy_file_to_screen(args.sample_path, args.GPU)
        # Every rank must reach this barrier, so it sits outside the
        # rank-0-only branch above.
        self.dist.barrier()

    # ==================================================================#
    # ==================================================================#
    def build_model(self):
        """Instantiate D, S, F and G and store them in ``self.nets``.

        Image-facing networks receive ``color_dim=args.mask_dim`` when
        ``args.TRAIN_MASK`` is set, otherwise ``color_dim=args.color_dim``.
        Architectures are printed only when training on the verbose rank.
        """
        from models.generator import Generator
        from models.discriminator import Discriminator
        from models.mapping import Noise2Style
        from models.encoder import StyleEncoder
        if self.args.TRAIN_MASK:
            in_dim = self.args.mask_dim
        else:
            in_dim = self.args.color_dim
        debug = self.args.mode == 'train' and self.verbose
        # Construction order (D, S, F, G) is kept as-is: constructors may
        # consume the RNG, so reordering could change initial weights.
        discriminator = Discriminator(self.args, color_dim=in_dim, debug=debug)
        style_encoder = StyleEncoder(self.args, color_dim=in_dim, debug=debug)
        mapping_network = Noise2Style(self.args, debug=debug)
        generator = Generator(self.args, color_dim=in_dim, debug=debug)
        self.nets = Munch(
            G=generator,
            F=mapping_network,
            S=style_encoder,
            D=discriminator,
        )
        if debug:
            for key, title in (('D', 'Discriminator'),
                               ('S', 'Style Encoder'),
                               ('F', 'Mapping'),
                               ('G', 'Generator')):
                self.print_network(self.nets[key], title)

    @staticmethod
    def _is_mapping_net(key):
        """Return True if *key* names the mapping network (trained at f_lr).

        An exact match is used so that other keys that merely contain "F"
        (e.g. ``'FAN'``) are not mistaken for the mapping network.
        """
        return key == 'F'

    def _build_optimizers(self):
        """Create one Adam optimizer per network in ``self.optims``.

        Under multi-process Horovod, parameters and optimizer state are
        broadcast from rank 0 and each optimizer is wrapped in
        ``DistributedOptimizer`` with gradient averaging.
        """
        self.optims = Munch()
        # A tuple, so the optimizers never share one mutable betas object.
        betas = (self.args.beta1, self.args.beta2)
        distributed = self.dist.size() > 1 and self.dist.horovod
        for name, net in self.nets.items():
            if self._is_mapping_net(name):
                lr = self.args.f_lr
            else:
                lr = self.args.lr
            optim = torch.optim.Adam(
                params=net.parameters(),
                lr=lr,
                betas=betas,
                weight_decay=self.args.weight_decay)
            if distributed:
                # Horovod: broadcast parameters & optimizer state.
                self.dist.broadcast_parameters(net.state_dict(), root_rank=0)
                self.dist.broadcast_optimizer_state(optim, root_rank=0)
                optim = self.dist.DistributedOptimizer(
                    optim,
                    named_parameters=net.named_parameters(),
                    op=self.dist.hvd.Average)
            self.optims[name] = optim

    def _attach_fan(self):
        """Load the FAN network and share one instance in nets and nets_ema."""
        from misc.wing import FAN
        fan = FAN(fname_pretrained='models/pretrained_models/wing.ckpt',
                  general_attr=self.args.GENERAL_HEATMAP)
        fan = to_cuda(fan, fixed=True)
        self.nets.FAN = fan
        self.nets_ema.FAN = fan

    def build_solver(self):
        """Create optimizers (train mode only), the EMA copy and extras.

        Order matters:
        1. Optimizers are built.
        2. ``nets_ema`` is deep-copied from ``nets``.
        3. A pretrained checkpoint is optionally loaded.
        4. The networks are moved with ``to_cuda``.
        5. FAN is attached when ``args.FAN`` is set and the dataset is not
           DeepFashion2.
        6. The selected domains are printed.
        """
        if self.args.mode == 'train':
            self._build_optimizers()
        # The copy is taken after the optimizers exist, so no optimizer
        # references the parameters of nets_ema.
        self.nets_ema = copy.deepcopy(self.nets)
        # Start with trained model
        if self.args.pretrained_model:
            self.load_pretrained_model()
        self._to_cuda()
        if self.args.FAN and self.args.dataset != 'DeepFashion2':
            self._attach_fan()
        domain_names = list(self.args.data_module.selected_attrs)
        _sr = '{} domains involved: {}'.format(len(domain_names),
                                               str(domain_names))
        self.PRINT(_sr)

    # ============================================================#
    # ============================================================#
    def update_lr(self, lr, f_lr):
        """Set the learning rate of every param group of every optimizer.

        Args:
            lr: New learning rate for every optimizer except ``F``.
            f_lr: New learning rate for the mapping network optimizer ``F``.
        """
        for key, optim in self.optims.items():
            new_lr = f_lr if self._is_mapping_net(key) else lr
            for param_group in optim.param_groups:
                param_group['lr'] = new_lr

    # ============================================================#
    # ============================================================#
    def reset_grad(self):
        """Zero the gradients of every optimizer."""
        for optim in self.optims.values():
            optim.zero_grad()

    # ==================================================================#
    # ==================================================================#
    def _to_cuda(self):
        """Pass every network in ``nets`` and then ``nets_ema`` to ``to_cuda``.

        ``fixed`` is forwarded from ``args.HOROVOD``; its exact meaning is
        defined by ``misc.utils.to_cuda`` (not visible here).
        """
        for nets in (self.nets, self.nets_ema):
            for key in list(nets):
                nets[key] = to_cuda(nets[key], fixed=self.args.HOROVOD)