-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy path vsum_train.py
158 lines (138 loc) · 6.46 KB
/
vsum_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import theano
from theano import tensor as T
import theano_nets
from model_reinforceRNN import reinforceRNN
import numpy as np
from datetime import datetime
import time, math, os, sys, h5py, logging, vsum_tools, argparse
from scipy.spatial.distance import cdist
# Floating-point dtype configured in Theano (theano.config.floatX); video
# features and per-video reward baselines are cast to it before being fed
# to the network.
_DTYPE = theano.config.floatX
def _setup_logger(log_dir):
    """Send INFO-level logs to ``<log_dir>/log.txt`` (overwritten) and the console.

    Returns the root logger.
    """
    fmt = '%(asctime)s %(message)s'
    datefmt = '[%d/%m/%Y %I:%M:%S]'
    logging.basicConfig(
        filename=os.path.join(log_dir, 'log.txt'),
        filemode='w',
        format=fmt,
        datefmt=datefmt,
        level=logging.INFO,
    )
    logger = logging.getLogger()
    # Add the console handler only once, so calling train() several times in
    # one process does not print every message multiple times.  (FileHandler
    # subclasses StreamHandler, hence the exact type check.)
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt))
        logger.addHandler(ch)
    return logger


def _learning_rate(i_epoch, base_lr, decay_rate, decay_stepsize):
    """Return the step-decayed learning rate for 0-based epoch ``i_epoch``.

    The rate is ``base_lr * decay_rate ** (i_epoch // decay_stepsize)``.
    A non-positive ``decay_stepsize`` (the CLI uses -1) disables decay; this
    also avoids division by zero and the rate *growth* that negative step
    sizes would otherwise cause.
    """
    if decay_stepsize <= 0:
        return base_lr
    return base_lr * (decay_rate ** (i_epoch // decay_stepsize))


def _reward_matrices(data_x, ignore_distant_sim, distant_sim_thre):
    """Build the pairwise frame matrices passed to ``net.model_train``.

    Returns ``(L_dissim_mat, L_distance_mat)``: ``L_dissim_mat`` is
    ``1 - data_x @ data_x.T`` and ``L_distance_mat`` holds the pairwise
    Euclidean distances between the rows of ``data_x``.  When
    ``ignore_distant_sim`` is set, frame pairs more than ``distant_sim_thre``
    positions apart get a dissimilarity of exactly 1.
    """
    L_distance_mat = cdist(data_x, data_x, 'euclidean')
    # NOTE(review): this is a cosine dissimilarity only if the rows of data_x
    # are L2-normalised; no normalisation happens in this script -- confirm
    # against how the dataset features were produced.
    L_dissim_mat = 1 - np.dot(data_x, data_x.T)
    if ignore_distant_sim:
        # |i - j| between frame indices.  Equivalent to the former
        # cdist(inds, inds, 'minkowski', 1) call, whose positional ``p``
        # argument is not accepted by current SciPy (kwargs are keyword-only).
        frame_idx = np.arange(data_x.shape[0])
        frame_gap = np.abs(frame_idx[:, None] - frame_idx[None, :])
        L_dissim_mat[frame_gap > distant_sim_thre] = 1
    return L_dissim_mat, L_distance_mat


def train(n_episodes=5,
          input_dim=1024,
          hidden_dim=256,
          W_init='normal',
          U_init='normal',
          weight_decay=1e-5,
          regularizer='L2',
          optimizer='adam',
          base_lr=1e-5,
          decay_rate=0.1,
          max_epochs=60,
          decay_stepsize=30,
          ignore_distant_sim=True,
          distant_sim_thre=20,
          alpha=0.01,
          model_file=None,
          disp_freq=1,
          train_dataset_path='datasets/eccv16_dataset_tvsum_google_pool5.h5',
          ):
    """Train a ``reinforceRNN`` model on every video in an HDF5 dataset.

    All keyword arguments are captured into ``model_options`` and handed to
    the ``reinforceRNN`` constructor.  Each epoch visits the videos in random
    order; for each one the network is trained on ``dataset[key]['features']``
    with a per-video moving-average reward baseline.  Logs go to
    ``log-train/log.txt`` and the console, and the trained network is saved
    into ``log-train`` as ``model_reinforceRNN``.

    Learning-rate decay: ``base_lr * decay_rate ** (epoch // decay_stepsize)``
    when ``decay_stepsize > 0``, otherwise a constant ``base_lr``.

    Raises:
        ValueError: if the dataset contains no videos.
    """
    # Must stay the first statement: locals() here is exactly the arguments.
    model_options = locals().copy()
    log_dir = 'log-train'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    logger = _setup_logger(log_dir)
    logger.info('model options: %s', model_options)

    logger.info('initializing net model')
    net = reinforceRNN(model_options)
    # NOTE(review): nothing in this function loads model_file; presumably the
    # reinforceRNN constructor does so from model_options -- confirm.
    if model_file is not None:
        logger.info('loaded model from %s', model_file)
    n_params = net.get_n_params()
    logger.info('net params size is %d', n_params)

    logger.info('loading training data from %s', train_dataset_path)
    # The context manager closes the HDF5 file even if training raises.
    with h5py.File(train_dataset_path, 'r') as dataset:
        # list() yields an indexable sequence on both Python 2 and 3.
        dataset_keys = list(dataset.keys())
        n_videos = len(dataset_keys)
        logger.info('total number of videos for training is %d', n_videos)
        if n_videos == 0:
            raise ValueError('no videos found in %s' % train_dataset_path)

        logger.info('=> training')
        start_time = time.time()
        # Per-video baseline rewards, updated as a moving average below.
        blrwds = {name: np.array(0).astype(_DTYPE) for name in dataset_keys}
        for i_epoch in range(max_epochs):
            indices = np.arange(n_videos)
            np.random.shuffle(indices)
            learn_rate = _learning_rate(i_epoch, base_lr, decay_rate,
                                        decay_stepsize)
            epoch_reward = 0.
            for index in indices:
                key = dataset_keys[index]
                data_x = dataset[key]['features'][...].astype(_DTYPE)
                L_dissim_mat, L_distance_mat = _reward_matrices(
                    data_x, ignore_distant_sim, distant_sim_thre)
                rewards = net.model_train(data_x, learn_rate, L_dissim_mat,
                                          L_distance_mat, blrwds[key])
                mean_reward = rewards.mean()
                # Exponential moving average of this video's mean reward.
                blrwds[key] = 0.9 * blrwds[key] + 0.1 * mean_reward
                epoch_reward += mean_reward
            epoch_reward /= n_videos

            is_display_epoch = disp_freq and (i_epoch + 1) % disp_freq == 0
            if is_display_epoch or (i_epoch + 1) == max_epochs:
                logger.info('epoch %3d/%d. reward %f.',
                            i_epoch + 1, max_epochs, epoch_reward)

        elapsed_time = time.time() - start_time
        logger.info('elapsed time %.2f s', elapsed_time)

    net.save_net(save_dir=log_dir, model_name='model_reinforceRNN')
if __name__ == '__main__':
    # Command-line entry point: every option maps onto one train() keyword.
    # NOTE: several CLI defaults differ from train()'s own defaults
    # (--decay-stepsize -1 vs 30, --ignore-distant-sim off vs True,
    # --dataset SumMe vs TVSum), so running this script is not the same as
    # calling train() with no arguments.
    parser = argparse.ArgumentParser(
        description="Train a reinforceRNN model on an HDF5 video-feature dataset.")
    parser.add_argument('--n-epi', type=int, default=5,
                        help="number of episodes for REINFORCE")
    parser.add_argument('--input-dim', type=int, default=1024,
                        help="input dimension, i.e. dimension of CNN features")
    parser.add_argument('--hidden-dim', type=int, default=256,
                        help="hidden dimension of RNN")
    parser.add_argument('--W-init', type=str, default='normal', choices=theano_nets.init_choices(),
                        help="initialization method for non-recurrent weights")
    parser.add_argument('--U-init', type=str, default='normal', choices=theano_nets.init_choices(),
                        help="initialization method for recurrent weights")
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help="coefficient for regularization on weight parameters")
    parser.add_argument('--reg', type=str, default='L2', choices=theano_nets.reg_choices(),
                        help="regularizer for weight parameters")
    parser.add_argument('--optim', type=str, default='adam', choices=theano_nets.optim_choices(),
                        help="optimizer used to update the network parameters")
    parser.add_argument('--base-lr', type=float, default=1e-5, help="base learning rate")
    parser.add_argument('--decay-rate', type=float, default=0.1, help="learning rate decay")
    parser.add_argument('--max-epochs', type=int, default=60, help="maximum training epochs")
    parser.add_argument('--decay-stepsize', type=int, default=-1,
                        help="stepsize to decay learning rate, if -1, then learning rate decay is disabled")
    parser.add_argument('--ignore-distant-sim', action='store_true',
                        help="whether to ignore the similarity between distant frames")
    parser.add_argument('--distant-sim-thre', type=int, default=20,
                        help="threshold to ignore similarity between distant frames")
    parser.add_argument('--alpha', type=float, default=0.01, help="coefficient for summary length penalty")
    parser.add_argument('--disp-freq', type=int, default=1, help="display frequency")
    parser.add_argument('--dataset', type=str, default='datasets/eccv16_dataset_summe_google_pool5.h5',
                        help="path to the HDF5 training dataset")
    # Exposes train()'s model_file parameter, previously unreachable from the CLI.
    parser.add_argument('--model-file', type=str, default=None,
                        help="optional model file, forwarded to train() as model_file")
    args = parser.parse_args()
    train(n_episodes=args.n_epi,
          input_dim=args.input_dim,
          hidden_dim=args.hidden_dim,
          W_init=args.W_init,
          U_init=args.U_init,
          weight_decay=args.weight_decay,
          regularizer=args.reg,
          optimizer=args.optim,
          base_lr=args.base_lr,
          decay_rate=args.decay_rate,
          max_epochs=args.max_epochs,
          decay_stepsize=args.decay_stepsize,
          ignore_distant_sim=args.ignore_distant_sim,
          distant_sim_thre=args.distant_sim_thre,
          alpha=args.alpha,
          model_file=args.model_file,
          disp_freq=args.disp_freq,
          train_dataset_path=args.dataset)