-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathneuralNetwork.js
More file actions
126 lines (98 loc) · 3.02 KB
/
neuralNetwork.js
File metadata and controls
126 lines (98 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import {
Array1D,
InCPUMemoryShuffledInputProviderBuilder,
Graph,
Session,
SGDOptimizer,
ENV,
NDArrayMath,
CostReduction,
} from 'deeplearn';
/**
 * Fully connected MNIST digit classifier built on the deeplearn Graph/Session API.
 *
 * Topology: an `inputSize`-wide placeholder (784 unrolled pixels by default),
 * followed by dense layers of `hiddenLayerUnits` units (64, 32, 16 by default),
 * followed by an `outputSize`-wide dense output layer (10 by default).
 * The model is trained with SGD against a mean-squared-error cost.
 *
 * Usage: construct, call `setupSession(trainingSet)` once, then call
 * `train(step, computeCost)` repeatedly and `predict(pixels)` as needed.
 */
class MnistModel {
  // Math environment captured at construction; used for the Session and for scoping.
  math = ENV.math;
  session;
  initialLearningRate = 0.06;
  // The learning rate is multiplied by `learningRateDecay` once every
  // `learningRateDecayInterval` training steps (see `train`).
  learningRateDecay = 0.9;
  learningRateDecayInterval = 50;
  optimizer;
  batchSize = 300;
  // Width of the input placeholder (28x28 unrolled pixels for MNIST).
  inputSize = 784;
  // Width of the target placeholder and of the final dense layer (one per digit).
  outputSize = 10;
  // Units of each hidden dense layer, in order from input to output.
  hiddenLayerUnits = [64, 32, 16];
  inputTensor;
  targetTensor;
  costTensor;
  predictionTensor;
  feedEntries;

  constructor() {
    this.optimizer = new SGDOptimizer(this.initialLearningRate);
  }

  /**
   * Builds the graph, prepares the shuffled training feeds and creates the Session.
   *
   * @param {Array<{input: number[], output: number[]}>} trainingSet - training samples;
   *   `input` has `inputSize` values, `output` has `outputSize` values (presumably one-hot).
   * @throws {Error} if `trainingSet` is not a non-empty array (see `prepareTrainingSet`).
   */
  setupSession(trainingSet) {
    const graph = new Graph();
    this.inputTensor = graph.placeholder('input unrolled pixels', [this.inputSize]);
    this.targetTensor = graph.placeholder('output digit classifier', [this.outputSize]);

    // Chain the hidden layers; layer names are fully_connected_0 .. fully_connected_{n-1}.
    const lastHiddenLayer = this.hiddenLayerUnits.reduce(
      (previousLayer, units, layerIndex) =>
        this.createFullyConnectedLayer(graph, previousLayer, layerIndex, units),
      this.inputTensor,
    );
    // NOTE(review): the output layer uses the same default relu activation as the
    // hidden layers (behavior kept as-is); a linear/softmax output may suit
    // classification better, so confirm this is intended.
    this.predictionTensor = this.createFullyConnectedLayer(
      graph,
      lastHiddenLayer,
      this.hiddenLayerUnits.length,
      this.outputSize,
    );
    this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);

    // Prepare the feeds before creating the session. If the data is rejected, the
    // model is then not left with a session that has nothing to train on.
    this.prepareTrainingSet(trainingSet);
    this.session = new Session(graph, this.math);
  }

  /**
   * Converts the samples to Array1D values and wires them into shuffled input
   * providers stored in `this.feedEntries`.
   *
   * The arrays are created while ENV.math is temporarily a CPU NDArrayMath
   * (safe mode off). Presumably Array1D.new allocates on ENV.math; confirm
   * against the deeplearn version in use. The previous ENV.math is always
   * restored, even if conversion throws.
   *
   * @param {Array<{input: number[], output: number[]}>} trainingSet - training samples.
   * @throws {Error} if `trainingSet` is not a non-empty array.
   */
  prepareTrainingSet(trainingSet) {
    if (!Array.isArray(trainingSet) || trainingSet.length === 0) {
      throw new Error(
        'MnistModel: prepareTrainingSet() requires a non-empty array of { input, output } samples',
      );
    }

    const previousMath = ENV.math;
    const safeMode = false;
    ENV.setMath(new NDArrayMath('cpu', safeMode));
    try {
      const inputArrays = trainingSet.map((sample) => Array1D.new(sample.input));
      const targetArrays = trainingSet.map((sample) => Array1D.new(sample.output));
      const providerBuilder = new InCPUMemoryShuffledInputProviderBuilder([
        inputArrays,
        targetArrays,
      ]);
      const [inputProvider, targetProvider] = providerBuilder.getInputProviders();
      this.feedEntries = [
        { tensor: this.inputTensor, data: inputProvider },
        { tensor: this.targetTensor, data: targetProvider },
      ];
    } finally {
      // Never leave the global math environment pointing at the temporary CPU math.
      ENV.setMath(previousMath);
    }
  }

  /**
   * Runs one training batch with a step-decayed learning rate.
   *
   * The learning rate is initialLearningRate * learningRateDecay ** floor(step / learningRateDecayInterval).
   *
   * @param {number} step - current training step, used for learning-rate decay.
   * @param {boolean} computeCost - when truthy, the mean batch cost is computed and returned.
   * @returns {number|undefined} the mean cost if `computeCost` is truthy, otherwise undefined.
   * @throws {Error} if `setupSession` has not been called.
   */
  train(step, computeCost) {
    if (!this.session) {
      throw new Error('MnistModel: setupSession() must be called before train()');
    }

    const decaySteps = Math.floor(step / this.learningRateDecayInterval);
    const learningRate = this.initialLearningRate * this.learningRateDecay ** decaySteps;
    this.optimizer.setLearningRate(learningRate);

    let costValue;
    this.math.scope(() => {
      const cost = this.session.train(
        this.costTensor,
        this.feedEntries,
        this.batchSize,
        this.optimizer,
        computeCost ? CostReduction.MEAN : CostReduction.NONE,
      );
      if (computeCost) {
        // Read the value inside the scope, before scoped arrays are released.
        costValue = cost.get();
      }
    });
    return costValue;
  }

  /**
   * Evaluates the prediction tensor for a single input.
   *
   * @param {number[]} pixels - `inputSize` input values.
   * @returns {number[]} a plain-array copy of the prediction values.
   * @throws {Error} if `setupSession` has not been called.
   */
  predict(pixels) {
    if (!this.session) {
      throw new Error('MnistModel: setupSession() must be called before predict()');
    }

    let values = [];
    this.math.scope(() => {
      const feed = [{ tensor: this.inputTensor, data: Array1D.new(pixels) }];
      values = this.session.eval(this.predictionTensor, feed).getValues();
    });
    // Copy into a plain array so callers do not hold the backend's typed array.
    return Array.from(values);
  }

  /**
   * Adds a dense layer named `fully_connected_${layerIndex}` to the graph.
   *
   * @param {Graph} graph - graph to add the layer to.
   * @param {*} inputLayer - tensor feeding the layer.
   * @param {number} layerIndex - index used in the layer name.
   * @param {number} units - number of output units.
   * @param {Function} [activationFunction] - activation; any falsy value selects relu.
   * @returns {*} the layer's output tensor, as returned by graph.layers.dense.
   */
  createFullyConnectedLayer(graph, inputLayer, layerIndex, units, activationFunction) {
    return graph.layers.dense(
      `fully_connected_${layerIndex}`,
      inputLayer,
      units,
      activationFunction || ((x) => graph.relu(x)),
    );
  }
}
export default MnistModel;