Skip to content

Commit d789eaa

Browse files
committed
fix logger repetitive output
1 parent 1a4da1f commit d789eaa

9 files changed

+44
-15
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
runs/
2+
__pycache__/

README.md

100644100755
File mode changed.

compute_results.py

100644100755
File mode changed.

data.py

100644100755
File mode changed.

evaluation.py

100644100755
+2-2
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def i2t(images, captions, npts=None, measure='cosine',
240240
"""
241241
if npts is None:
242242
npts = int(images.shape[0] / 5)
243-
print(npts)
243+
#print(npts)
244244
index_list = []
245245

246246
scores = images.dot(captions.T)
@@ -297,7 +297,7 @@ def t2i(images, captions, npts=None, measure='cosine',
297297
"""
298298
if npts is None:
299299
npts = int(images.shape[0] / 5)
300-
print(npts)
300+
#print("# points:", npts)
301301
ims = numpy.array([images[i] for i in range(0, len(images), 5)])
302302

303303
scores = captions.dot(ims.T)

model.py

100644100755
File mode changed.

requirements.txt

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Automatically generated by https://github.com/damnever/pigar.
2+
3+
# HAL/data.py: 6
4+
Pillow == 7.0.0
5+
6+
# HAL/data.py: 5
7+
# HAL/vocab.py: 2
8+
nltk == 3.4.5
9+
10+
# HAL/data.py: 8
11+
# HAL/evaluation.py: 5,8
12+
# HAL/model.py: 10
13+
numpy == 1.18.1
14+
15+
# HAL/data.py: 7
16+
# HAL/vocab.py: 5
17+
pycocotools-fix == 2.0.0.9
18+
19+
# HAL/data.py: 7
20+
# HAL/vocab.py: 5
21+
pycocotools-win == 2.0
22+
23+
# HAL/train.py: 9
24+
tensorboard_logger == 0.1.0
25+
26+
# HAL/data.py: 1,2
27+
# HAL/evaluation.py: 10
28+
# HAL/model.py: 7,8,9
29+
# HAL/train.py: 7
30+
torch == 1.5.1
31+
32+
# HAL/data.py: 3
33+
# HAL/model.py: 5
34+
torchvision == 0.6.0a0+35d732a

train.py

100644100755
+6-13
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,17 @@
22
import os
33
import time
44
import shutil
5-
5+
from random import random
6+
import argparse
67
import torch
8+
import logging
9+
import tensorboard_logger as tb_logger
710

811
import data
912
from vocab import Vocabulary # NOQA
1013
from model import VSE
1114
from evaluation import i2t, t2i, AverageMeter, LogCollector, encode_data
1215

13-
import logging
14-
import tensorboard_logger as tb_logger
15-
from random import random
16-
17-
import argparse
18-
19-
2016
def main():
2117
# Hyper Parameters
2218
parser = argparse.ArgumentParser()
@@ -101,10 +97,6 @@ def main():
10197

10298
logging.basicConfig(format='%(message)s', level=logging.INFO)
10399
tb_logger.configure(opt.logger_name, flush_secs=5)
104-
logger = logging.getLogger()
105-
sh = logging.StreamHandler()
106-
logger.addHandler(sh)
107-
logger.setLevel(logging.INFO)
108100

109101
# Load Vocabulary Wrapper
110102
vocab = pickle.load(open(os.path.join(
@@ -149,7 +141,7 @@ def main():
149141

150142
# evaluate on validation set
151143
rsum = validate(opt, val_loader, model)
152-
print ("rsum:",rsum)
144+
print ("rsum: %.1f" % rsum)
153145
if opt.record_val:
154146
with open("rst_val_" + opt.logger_name[5:], "a") as f:
155147
f.write("Epoch: %d ; rsum: %.1f\n" %(epoch, rsum))
@@ -283,6 +275,7 @@ def validate(opt, val_loader, model):
283275
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', prefix='', save_all=True):
284276
torch.save(state, prefix + filename)
285277
if is_best:
278+
print ("[Best model sofar, saved.]")
286279
shutil.copyfile(prefix + filename, prefix + 'model_best.pth.tar')
287280
if save_all:
288281
shutil.copyfile(prefix + filename, prefix + "Epoch-" + str(state['epoch']) + "-" + 'model.pth.tar')

vocab.py

100644100755
File mode changed.

0 commit comments

Comments
 (0)