Skip to content

Commit df85315

Browse files
author
kmbae
committed
Final update
1 parent 86dd11d commit df85315

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ This code is a fork from the code for "Learning to Discover Cross-Domain Relatio
1212

1313
$ pip install -r requirements.txt
1414

15-
15+
Pytorch==0.4.1 required
1616
## Start training
1717

1818
$ python main.py

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch
1+
torch==0.4.1
22
numpy
33
Pillow
44
glob

trainer.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ def train(self):
209209
valid_x_A1, valid_x_B1=torch.Tensor(np.load(self.config.dataset_A2+'_A.npy')), torch.Tensor(np.load(self.config.dataset_A2+'_B.npy'))
210210
valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
211211
except:
212-
print('Cannot load validation file. Validation data not created')
213-
assert 1
212+
raise Exception('Cannot load validation file. Validation data not created')
214213
'''
215214
x_A1 = A_loader.next()
216215
x_A2 = A1_loader.next()
@@ -255,10 +254,6 @@ def train(self):
255254

256255
x_A1, x_B1 = x_A1['image'], x_A1['edges']
257256
x_A2, x_B2 = x_A2['image'], x_A2['edges']
258-
#writer.add_image('x_A1', x_A1[:16],step)
259-
#writer.add_image('x_A2', x_A2[:16],step)
260-
#writer.add_image('x_B1', x_B1[:16],step)
261-
#writer.add_image('x_B2', x_B2[:16],step)
262257
if x_A1.size(0) != x_B1.size(0) or x_A2.size(0) != x_B2.size(0) or x_A1.size(0) != x_A2.size(0):
263258
print("[!] Sampled dataset from A and B have different # of data. Try resampling...")
264259
continue
@@ -269,7 +264,23 @@ def train(self):
269264
batch_size = x_A1.size(0)
270265
real_tensor.data.resize_(batch_size).fill_(real_label)
271266
fake_tensor.data.resize_(batch_size).fill_(fake_label)
272-
267+
"""
268+
####### Debugging
269+
ipdb.set_trace()
270+
x_A1.data.resize_(x_A1.shape).fill_(real_label)
271+
x_B1.data.resize_(x_B1.shape).fill_(real_label)
272+
x_A2.data.resize_(x_A2.shape).fill_(real_label)
273+
x_B2.data.resize_(x_B2.shape).fill_(real_label)
274+
275+
self.G.load_state_dict(torch.load('G_0.pth'))
276+
self.F.load_state_dict(torch.load('F_0.pth'))
277+
278+
self.D_S.load_state_dict(torch.load('D_A2_0.pth'))
279+
self.D_H.load_state_dict(torch.load('D_A1_0.pth'))
280+
281+
self.D_B1.load_state_dict(torch.load('D_B1_0.pth'))
282+
self.D_B2.load_state_dict(torch.load('D_B2_0.pth'))
283+
"""
273284
## Update Db network
274285
#self.D_B1.zero_grad()
275286
#self.D_B2.zero_grad()
@@ -478,11 +489,11 @@ def train(self):
478489
step)
479490
writer.add_scalars('loss_F', {'l_const_A':l_const_A, 'l_const_AB': l_const_AB}, step)
480491
#'l_const_B':l_const_B,'l_const_AB':l_const_AB,'l_const_BA':l_const_BA}, step)
481-
writer.add_scalars('loss_D', {'l_d_A':l_d_A,'l_d_B':l_d_B}, step)
482-
writer.add_image('x_A1', x_A1, step)
483-
writer.add_image('x_B1', x_B1, step)
484-
writer.add_image('x_A2', x_A2, step)
485-
writer.add_image('x_B2', x_B2, step)'''
492+
writer.add_scalars('loss_D', {'l_d_A':l_d_A,'l_d_B':l_d_B}, step)'''
493+
#writer.add_image('x_A1', x_A1[:16], step)
494+
#writer.add_image('x_B1', x_B1[:16], step)
495+
#writer.add_image('x_A2', x_A2[:16], step)
496+
#writer.add_image('x_B2', x_B2[:16], step)
486497
# Discriminator loss
487498
writer.add_scalar('L_d_A', l_dA, step)
488499
writer.add_scalar('L_d_B', l_dB, step)
@@ -498,8 +509,11 @@ def train(self):
498509
torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.model_dir, step))
499510
torch.save(self.F.state_dict(), '{}/F_{}.pth'.format(self.model_dir, step))
500511

501-
torch.save(self.D_S.state_dict(), '{}/D_A_{}.pth'.format(self.model_dir, step))
502-
torch.save(self.D_H.state_dict(), '{}/D_B_{}.pth'.format(self.model_dir, step))
512+
torch.save(self.D_H.state_dict(), '{}/D_A1_{}.pth'.format(self.model_dir, step))
513+
torch.save(self.D_S.state_dict(), '{}/D_A2_{}.pth'.format(self.model_dir, step))
514+
515+
torch.save(self.D_B1.state_dict(), '{}/D_B1_{}.pth'.format(self.model_dir, step))
516+
torch.save(self.D_B2.state_dict(), '{}/D_B2_{}.pth'.format(self.model_dir, step))
503517

504518
def generate_with_A(self, inputs, input_ref, path, idx=None, tf_board=True):
505519
x_AB = self.F(inputs)
@@ -524,7 +538,7 @@ def generate_with_A(self, inputs, input_ref, path, idx=None, tf_board=True):
524538
writer.add_image('x_A1valid', inputs[:16], idx)
525539
writer.add_image('x_A1G', x_ABA[:16], idx)
526540
#writer.add_image('x_B1_2f', x_ABAf[:16], idx)
527-
writer.add_image('x_A1_rec', x_ABAB[:16], idx)
541+
writer.add_image('x_A1rec', x_ABAB[:16], idx)
528542
#writer.add_image('x_ABA', x_ABA, idx)
529543
#vutils.save_image(x_ABA.data, x_ABA_path)
530544
#print("[*] Samples saved: {}".format(x_ABA_path))
@@ -552,7 +566,7 @@ def generate_with_B(self, inputs, input_ref, path, idx=None, tf_board=True):
552566
writer.add_image('x_A2valid', inputs[:16], idx)
553567
writer.add_image('x_A2G', x_BAB[:16], idx)
554568
#writer.add_image('x_B1_1f', x_ABAf[:16], idx)
555-
writer.add_image('x_A2_rec', x_ABAB[:16], idx)
569+
writer.add_image('x_A2rec', x_ABAB[:16], idx)
556570
#writer.add_image('x_BAB', x_BAB, idx)
557571
#vutils.save_image(x_BAB.data, x_BAB_path)
558572
#print("[*] Samples saved: {}".format(x_BAB_path))
@@ -573,9 +587,9 @@ def generate_infinitely(self, inputs, path, input_type, count=10, nrow=2, idx=No
573587

574588
def test(self):
575589
batch_size = self.config.sample_per_image
576-
x_A1, x_B1 = torch.Tensor(np.load('../valid_x_A1.npy')), torch.Tensor(np.load('../valid_x_B1.npy'))
590+
x_A1, x_B1 = torch.Tensor(np.load('./valid_x_A1.npy')), torch.Tensor(np.load('./valid_x_B1.npy'))
577591
x_A1, x_B1 = self._get_variable(x_A1), self._get_variable(x_B1)
578-
x_A2, x_B2=torch.Tensor(np.load('../valid_x_A2.npy')), torch.Tensor(np.load('../valid_x_B2.npy'))
592+
x_A2, x_B2=torch.Tensor(np.load('./valid_x_A2.npy')), torch.Tensor(np.load('./valid_x_B2.npy'))
579593
x_A2, x_B2 = self._get_variable(x_A2), self._get_variable(x_B2)
580594

581595
test_dir = os.path.join(self.model_dir, 'test')

0 commit comments

Comments
 (0)