-
Notifications
You must be signed in to change notification settings - Fork 379
/
main.py
48 lines (40 loc) · 1.35 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import cPickle as pickle
import gzip
import numpy
from midi_to_statematrix import *
import multi_training
import model
def gen_adaptive(m, pcs, times, keep_thoughts=False, name="final"):
    """Generate music from the model, adapting the ``cons`` knob step by step.

    A segment is obtained from ``pcs`` via ``multi_training.getPieceSegment``.
    Both parts are converted to int8 arrays. The first input time step seeds
    ``m.start_slow_walk``, and the first output time step becomes the first
    frame of the result. The model is then stepped
    ``multi_training.batch_len * times`` times with ``m.slow_walk_fun(cons)``.
    After every step, ``cons`` is adjusted according to how many notes the new
    frame contains.

    Args:
        m: model exposing ``start_slow_walk`` and ``slow_walk_fun``
            (presumably a ``model.Model`` -- confirm).
        pcs: piece collection handed to ``multi_training.getPieceSegment``.
        times: number of ``batch_len``-sized stretches to generate.
        keep_thoughts: if true, every raw ``slow_walk_fun`` result is also
            pickled to ``'output/' + name + '.p'``.
        name: base name for the files written under ``output/``.

    Returns:
        None. The generated frames are written by ``noteStateMatrixToMidi``
        to ``'output/' + name``.
    """
    x_ipt, x_opt = (numpy.array(x, dtype='int8')
                    for x in multi_training.getPieceSegment(pcs))
    all_outputs = [x_opt[0]]
    all_thoughts = [] if keep_thoughts else None
    m.start_slow_walk(x_ipt[0])

    cons = 1
    for _ in range(multi_training.batch_len * times):
        resdata = m.slow_walk_fun(cons)
        frame = resdata[-1]  # the newly generated time step
        # Column 0 of the frame is summed as the note count -- presumably the
        # per-note "play" flag of the state matrix; confirm in midi_to_statematrix.
        nnotes = numpy.sum(frame[:, 0])
        if nnotes < 2:
            # Too sparse: clamp cons to at most 1, then lower it. This
            # presumably makes the next step more willing to play notes;
            # the effect of cons is defined by slow_walk_fun.
            if cons > 1:
                cons = 1
            cons -= 0.02
        else:
            # Enough notes: move cons 30% of the remaining way back towards 1.
            cons += (1 - cons) * 0.3
        all_outputs.append(frame)
        if keep_thoughts:
            all_thoughts.append(resdata)

    noteStateMatrixToMidi(numpy.array(all_outputs), 'output/' + name)
    if keep_thoughts:
        # Close (and flush) the file deterministically rather than relying on
        # garbage collection of an anonymous file object.
        with open('output/' + name + '.p', 'wb') as f:
            pickle.dump(all_thoughts, f)
def fetch_train_thoughts(m, pcs, batches, name="trainthoughts"):
    """Collect the model's internal "thoughts" for several training batches.

    For each of ``batches`` iterations, a batch is drawn with
    ``multi_training.getPieceBatch(pcs)`` and passed to
    ``m.update_thought_fun(ipt, opt)``. The triple ``(ipt, opt, thoughts)``
    is recorded for each batch. All triples are pickled together, as one
    list, to ``'output/' + name + '.p'``.

    Args:
        m: model exposing ``update_thought_fun`` (presumably ``model.Model``).
        pcs: piece collection handed to ``multi_training.getPieceBatch``.
        batches: number of batches to sample.
        name: base name of the pickle file written under ``output/``.

    Returns:
        None.
    """
    all_thoughts = []
    for _ in range(batches):
        ipt, opt = multi_training.getPieceBatch(pcs)
        thoughts = m.update_thought_fun(ipt, opt)
        all_thoughts.append((ipt, opt, thoughts))
    # `with` guarantees the file is flushed and closed once the dump completes.
    with open('output/' + name + '.p', 'wb') as f:
        pickle.dump(all_thoughts, f)
if __name__ == '__main__':
    import os  # only the script entry point needs it

    # The final config below is written into ./output. Create the directory
    # up front so that a missing directory is not discovered only after the
    # full training run has finished.
    if not os.path.isdir('output'):
        os.makedirs('output')

    # "music" is presumably a directory of training MIDI files -- see
    # multi_training.loadPieces.
    pcs = multi_training.loadPieces("music")
    # Layer-size lists and dropout as expected by model.Model; see that class
    # for what each list controls.
    m = model.Model([300, 300], [100, 50], dropout=0.5)
    # 10000 is presumably the number of training iterations -- see trainPiece.
    multi_training.trainPiece(m, pcs, 10000)

    # Persist the trained configuration; `with` guarantees the file is closed.
    with open("output/final_learned_config.p", "wb") as f:
        pickle.dump(m.learned_config, f)