-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNeuralNetwork.js
75 lines (69 loc) · 1.99 KB
/
NeuralNetwork.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/**
 * Small two-layer feed-forward network backed by a TensorFlow.js sequential
 * model (dense sigmoid hidden layer -> dense softmax output layer), exposing
 * the predict / copy / mutate operations used by a neuroevolution loop.
 *
 * Depends on globals defined outside this file: `tf` (TensorFlow.js) and,
 * for `mutate`, `random` / `randomGaussian` (presumably p5.js — confirm).
 */
class NeuralNetwork {
  /**
   * @param {number} input_nodes - Length of each input sample.
   * @param {number} hidden_nodes - Number of units in the hidden layer.
   * @param {number} output_nodes - Number of units in the output layer.
   * @param {object|null} [model=null] - Existing tf model to wrap; when falsy,
   *   a fresh model is built with createModel().
   */
  constructor(input_nodes, hidden_nodes, output_nodes, model = null) {
    // Public snake_case fields are kept unchanged because callers may read them.
    this.input_nodes = input_nodes;
    this.hidden_nodes = hidden_nodes;
    this.output_nodes = output_nodes;
    // Must come after the node counts are set: createModel() reads them.
    this.model = model || this.createModel();
  }

  /**
   * Builds the default architecture from the stored node counts.
   * @returns {object} A tf.sequential() model with a sigmoid hidden layer
   *   and a softmax output layer.
   */
  createModel() {
    const model = tf.sequential();
    model.add(
      tf.layers.dense({
        units: this.hidden_nodes,
        inputShape: [this.input_nodes],
        activation: "sigmoid",
      })
    );
    model.add(
      tf.layers.dense({
        units: this.output_nodes,
        activation: "softmax",
      })
    );
    return model;
  }

  /**
   * Runs a single sample through the model.
   * @param {number[]} inputs - One sample, wrapped as a one-row 2-D tensor.
   * @returns {ArrayLike<number>} The output values read synchronously via
   *   dataSync() (a typed array).
   */
  predict(inputs) {
    // tidy disposes the input/output tensors; the returned typed array is
    // plain data and outlives them.
    return tf.tidy(() => {
      const inputTensor = tf.tensor2d([inputs]);
      const outputTensor = this.model.predict(inputTensor);
      return outputTensor.dataSync();
    });
  }

  /**
   * Returns an independent network with the same node counts, whose model is
   * a fresh createModel() instance loaded with clones of this model's weights.
   *
   * NOTE(review): the copy always uses createModel()'s architecture, so a
   * custom model passed to the constructor with a different layout will
   * likely be rejected by setWeights().
   * @returns {NeuralNetwork}
   */
  copy() {
    let modelCopy;
    // The tidy callback returns nothing on purpose. It only needs to clean up
    // the temporary weight clones. The new model's variables survive because
    // tf.tidy never disposes variables. Returning the wrapper object from
    // tidy would make tidy walk the entire wrapped model looking for tensors.
    tf.tidy(() => {
      modelCopy = this.createModel();
      const weightCopies = this.model.getWeights().map((weight) => weight.clone());
      modelCopy.setWeights(weightCopies);
    });
    return new NeuralNetwork(
      this.input_nodes,
      this.hidden_nodes,
      this.output_nodes,
      modelCopy
    );
  }

  /**
   * Perturbs weights in place on this network's model. Each individual weight
   * value is shifted by randomGaussian() when random(1) < rate, i.e. with
   * probability `rate` assuming random(1) is uniform in [0, 1).
   * @param {number} rate - Per-weight mutation probability.
   */
  mutate(rate) {
    tf.tidy(() => {
      const mutatedWeights = this.model.getWeights().map((weight) => {
        // slice() so edits go to a private copy, never to whatever buffer
        // dataSync() handed back.
        const values = weight.dataSync().slice();
        for (let i = 0; i < values.length; i += 1) {
          if (random(1) < rate) {
            values[i] += randomGaussian();
          }
        }
        return tf.tensor(values, weight.shape);
      });
      this.model.setWeights(mutatedWeights);
    });
  }

  /** Releases the wrapped model's resources via model.dispose(). */
  dispose() {
    this.model.dispose();
  }
}