|
17 | 17 | <script>
|
18 | 18 | tf = faceapi.tf
|
19 | 19 |
|
20 |
| - const weightsUrl = '/tmp/test.weights' |
21 |
| - const fromEpoch = 800 |
| 20 | + const weightsUrl = '/tmp/initial_tiny_yolov2_glorot_normal.weights' |
| 21 | + const fromEpoch = 0 |
22 | 22 |
|
23 | 23 |
|
24 |
| - window.debug = true |
25 |
| - window.logTrainSteps = true |
| 24 | + window.debug = false |
| 25 | + window.logTrainSteps = false |
26 | 26 |
|
27 | 27 |
|
28 | 28 | // hyper parameters
|
29 | 29 | window.objectScale = 5
|
30 | 30 | window.noObjectScale = 1
|
31 | 31 | window.coordScale = 1
|
32 | 32 |
|
33 |
| - window.saveEveryNthIteration = 50 |
| 33 | + window.saveEveryNthIteration = 1 |
34 | 34 | window.trainSteps = 4000
|
35 | 35 | //window.optimizer = tf.train.sgd(0.001)
|
36 | 36 | window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
|
37 | 37 |
|
38 |
| - const numTrainSamples = 1 |
| 38 | + // all samples |
| 39 | + const numTrainSamples = Infinity |
39 | 40 |
|
40 | 41 | async function loadNetWeights(uri) {
|
41 | 42 | return new Float32Array(await (await fetch(uri)).arrayBuffer())
|
|
50 | 51 | window.net = new faceapi.TinyYolov2(true)
|
51 | 52 | window.net.load(weights)
|
52 | 53 | window.net.variable()
|
53 |
| - window.detectionFilenames = (await fetchDetectionFilenames()).slice(1, numTrainSamples + 1) |
| 54 | + window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples) |
54 | 55 | window.lossMap = {}
|
55 | 56 |
|
56 | 57 | console.log('ready')
|
57 | 58 | }
|
58 | 59 |
|
59 |
| - const trainSizes = [608] |
| 60 | + const trainSizes = [320] |
| 61 | + |
| 62 | + function logLossChange(lossType) { |
| 63 | + const { currentLoss, prevLoss, detectionFilenames } = window |
| 64 | + log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`) |
| 65 | + } |
60 | 66 |
|
61 | 67 | async function train(batchSize = 1) {
|
62 | 68 | for (let i = fromEpoch; i < trainSteps; i++) {
|
|
81 | 87 | log()
|
82 | 88 | log('step %s done (%s ms)', i, ts)
|
83 | 89 |
|
84 |
| - const currentLoss = Object.keys(lossMap).map(k => lossMap[k]).reduce((sum, l) => sum + l) |
| 90 | + window.prevLoss = window.currentLoss |
| 91 | + window.currentLoss = Object.keys(lossMap) |
| 92 | + .map(filename => lossMap[filename]) |
| 93 | + .reduce((accumulatedLosses, losses) => |
| 94 | + Object.keys(losses) |
| 95 | + .map(key => ({ |
| 96 | + [key]: (accumulatedLosses[key] || 0) + losses[key] |
| 97 | + })) |
| 98 | + .reduce((map, curr) => ({ ...map, ...curr }), {}), |
| 99 | + {} |
| 100 | + ) |
| 101 | + |
85 | 102 | if (window.prevLoss) {
|
86 |
| - log('prevLoss:', window.prevLoss) |
87 |
| - log('currentLoss:', currentLoss) |
88 |
| - log('loss change:', currentLoss - window.prevLoss) |
| 103 | + logLossChange('noObjectLoss') |
| 104 | + logLossChange('objectLoss') |
| 105 | + logLossChange('coordLoss') |
| 106 | + logLossChange('totalLoss') |
89 | 107 | }
|
90 | 108 | log()
|
91 | 109 | log('--------------------')
|
92 | 110 | log()
|
93 | 111 |
|
94 |
| - window.prevLoss = currentLoss |
95 |
| - |
96 |
| - |
97 | 112 | if (((i + 1) % saveEveryNthIteration) === 0) {
|
98 |
| - saveWeights(window.net, 'adam_511_' + i + '.weights') |
| 113 | + saveWeights(window.net, 'adam_511_' + (i + 1) + '.weights') |
99 | 114 | }
|
100 | 115 | }
|
101 | 116 | }
|
|
0 commit comments