forked from jmsbrcwll/SpikeProp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainNetwork.m
54 lines (46 loc) · 1.5 KB
/
trainNetwork.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
%this script will run everything
%
% Builds a small weight tensor, then repeatedly calls spikePropAlgorithm on
% every training example and records the mean error of each iteration in
% meanErrorLog (one scalar per iteration).
%
% Workspace variables left behind for inspection:
%   weights       - 3x2x2 weight array as last returned by spikePropAlgorithm
%   errors        - column vector, error of each training example (last iteration)
%   meanErrorLog  - column vector, mean error of each iteration
%   realMeanError - mean baseline error on the loaded data set (see below)

% ---- recorded data ------------------------------------------------------
% Both .mat files are expected to provide a variable named peakLocs; the
% second load replaces the first, so the inputs are copied out in between.
load('spiketime_inputs.mat');
input_fire_times = peakLocs(:,1:8);
load('spikeTimings_desired.mat');
desired_output_fire_times = peakLocs;
[input_fire_times, desired_output_fire_times] = normalizeSounds(input_fire_times, desired_output_fire_times);

% ---- network layout -----------------------------------------------------
% Entries are per-layer node counts as passed to spikePropAlgorithm:
% 2, 2, 2 and 1.
layer_node_num = [2; 2; 2; 1];

% ---- initial weights ----------------------------------------------------
% Random values in [1/3, 4/3). The order of the rand calls is kept so that a
% seeded run draws the same numbers as before. Only weights(3,:,1) is
% initialised for the third slice; weights(3,:,2) stays zero.
weights = zeros(3,2,2);
weights(1,1:2,:) = rand(2,2) + 1/3;
weights(2,:,:) = rand(2,2) + 1/3;
weights(3,:,1) = rand(2,1) + 1/3;
%
% for i = 1:3
% for j = 1:8
% weights(i,j,j) = 4;
% end
%
% end

% ---- baseline error on the loaded data ----------------------------------
% zeros(n) would create an n-by-n matrix; a column vector is what is needed
% so that sum() yields a scalar.
errors = zeros(size(input_fire_times,1), 1);
%
% for i = 1:size(input_fire_times,1)
% errors(i)= getError(desired_output_fire_times(i,:)',input_fire_times(i,:));
%
%
% end
% Divide by the actual number of examples instead of a hard-coded count.
% While the loop above stays commented out this is always 0.
realMeanError = sum(errors) / max(numel(errors), 1);

% ---- training set actually used -----------------------------------------
% NOTE(review): these hard-coded examples replace the data loaded above, so
% the loaded spike times currently have no influence on training.
input_fire_times = [0 0.06; 0.06 0; 0 0.6; 0.1 0.1];
desired_output_fire_times = [0.2; 0.4; 0.4; 0.2];

num_iterations = 1000;
num_examples = size(input_fire_times, 1);

% Preallocate: one error per example, one mean error per iteration.
errors = zeros(num_examples, 1);
meanErrorLog = zeros(num_iterations, 1);

% ---- training loop ------------------------------------------------------
for iter = 1:num_iterations
    %loop through training examples
    for i = 1:num_examples
        % The literal 1 is forwarded unchanged as the fourth argument;
        % presumably a learning rate -- TODO confirm against spikePropAlgorithm.
        [weights, fire_times] = spikePropAlgorithm(input_fire_times(i,:), desired_output_fire_times(i,:), weights, 1, layer_node_num);
        % Row 4 of fire_times is compared with the target; presumably the
        % output layer (layer_node_num has 4 entries) -- TODO confirm.
        errors(i) = getError(desired_output_fire_times(i,:)', fire_times(4,:));
    end
    meanError = sum(errors) / num_examples;
    meanErrorLog(iter) = meanError;
    if mod(iter,50) == 0
        % Periodic progress report (replaces a placeholder assignment that
        % had no effect).
        fprintf('iteration %d: mean error = %g\n', iter, meanError);
    end
end