Skip to content

Commit a192076

Browse files
author
kmbae
committed
Remove unnecessary dataloader
1 parent 7de0ea0 commit a192076

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

data_loader.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -123,18 +123,18 @@ def __len__(self):
123123

124124
def get_loader(root, batch_size, scale_size, num_workers=2,
125125
skip_pix2pix_processing=False, shuffle=True):
126-
a_data_set, b_data_set = \
127-
Dataset(root, scale_size, "A", skip_pix2pix_processing), \
128-
Dataset(root, scale_size, "B", skip_pix2pix_processing)
126+
a_data_set = \
127+
Dataset(root, scale_size, "A", skip_pix2pix_processing)#, \
128+
#Dataset(root, scale_size, "B", skip_pix2pix_processing)
129129
a_data_loader = torch.utils.data.DataLoader(dataset=a_data_set,
130130
batch_size=batch_size,
131131
shuffle=True,
132132
num_workers=num_workers)
133-
b_data_loader = torch.utils.data.DataLoader(dataset=b_data_set,
134-
batch_size=batch_size,
135-
shuffle=True,
136-
num_workers=num_workers)
133+
#b_data_loader = torch.utils.data.DataLoader(dataset=b_data_set,
134+
# batch_size=batch_size,
135+
# shuffle=True,
136+
# num_workers=num_workers)
137137
a_data_loader.shape = a_data_set.shape
138-
b_data_loader.shape = b_data_set.shape
138+
#b_data_loader.shape = b_data_set.shape
139139

140-
return a_data_loader, b_data_loader
140+
return a_data_loader#, b_data_loader

main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def main(config):
2222
data_path = config.test_data_path
2323
batch_size = config.sample_per_image
2424

25-
a_data_loader, b_data_loader = get_loader(
25+
a_data_loader = get_loader(
2626
config.dataset_A1, batch_size, config.input_scale_size,
2727
config.num_worker, config.skip_pix2pix_processing)
2828

29-
a1_data_loader, b1_data_loader = get_loader(
29+
a1_data_loader = get_loader(
3030
config.dataset_A2, batch_size, config.input_scale_size,
3131
config.num_worker, config.skip_pix2pix_processing)
3232

33-
trainer = Trainer(config, a_data_loader, b_data_loader, a1_data_loader, b1_data_loader)
33+
trainer = Trainer(config, a_data_loader, a1_data_loader)
3434

3535
if config.is_train:
3636
save_config(config)

trainer.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ def weights_init(m):
3232
m.bias.data.fill_(0)
3333

3434
class Trainer(object):
35-
def __init__(self, config, a_data_loader, b_data_loader, a1_data_loader, b1_data_loader):
35+
def __init__(self, config, a_data_loader, a1_data_loader):
3636
self.config = config
3737

3838
self.a_data_loader = a_data_loader
39-
self.b_data_loader = b_data_loader
39+
#self.b_data_loader = b_data_loader
4040
self.a1_data_loader = a1_data_loader
41-
self.b1_data_loader = b1_data_loader
41+
#self.b1_data_loader = b1_data_loader
4242

4343
self.num_gpu = config.num_gpu
4444
self.dataset = config.dataset
@@ -90,7 +90,7 @@ def build_model(self):
9090
self.D_B = DiscriminatorFC(2, 1, [config.fc_hidden_dim] * config.d_num_layer)
9191
else:
9292
a_height, a_width, a_channel = self.a_data_loader.shape
93-
b_height, b_width, b_channel = self.b_data_loader.shape
93+
b_height, b_width, b_channel = self.a1_data_loader.shape
9494

9595
if self.cnn_type == 0:
9696
#conv_dims, deconv_dims = [64, 128, 256, 512], [512, 256, 128, 64]
@@ -253,6 +253,10 @@ def train(self):
253253

254254
x_A1, x_B1 = x_A1['image'], x_A1['edges']
255255
x_A2, x_B2 = x_A2['image'], x_A2['edges']
256+
writer.add_image('x_A1', x_A1[:16],step)
257+
writer.add_image('x_A2', x_A2[:16],step)
258+
writer.add_image('x_B1', x_B1[:16],step)
259+
writer.add_image('x_B2', x_B2[:16],step)
256260
if x_A1.size(0) != x_B1.size(0) or x_A2.size(0) != x_B2.size(0) or x_A1.size(0) != x_A2.size(0):
257261
print("[!] Sampled dataset from A and B have different # of data. Try resampling...")
258262
continue

0 commit comments

Comments
 (0)