Skip to content

Commit bcb9883

Browse files
check in latest training script
1 parent 6c3c55e commit bcb9883

File tree

5 files changed

+62
-21
lines changed

5 files changed

+62
-21
lines changed

tools/train/serveTinyYolov2.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ app.use(express.static(imagesPath))
2121
app.use(express.static(detectionsPath))
2222

2323
const detectionFilenames = fs.readdirSync(detectionsPath)
24+
const detectionFilenamesMultibox = JSON.parse(fs.readFileSync(path.join(__dirname, './tinyYolov2/multibox.json')))
2425

2526
app.use(express.static(trainDataPath))
2627

2728
app.get('/detection_filenames', (req, res) => res.status(202).send(detectionFilenames))
29+
app.get('/detection_filenames_multibox', (req, res) => res.status(202).send(detectionFilenamesMultibox))
2830
app.get('/', (req, res) => res.sendFile(path.join(publicDir, 'train.html')))
2931
app.get('/verify', (req, res) => res.sendFile(path.join(publicDir, 'verify.html')))
3032

tools/train/tinyYolov2/multibox.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tools/train/tinyYolov2/train.html

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
<script>
1818
tf = faceapi.tf
1919

20-
const startIdx224 = 3220
21-
const startIdx320 = 20688
22-
const startIdx416 = 950
23-
const startIdx608 = 15220
20+
const startIdx224 = 35060
21+
const startIdx320 = 41188
22+
const startIdx416 = 31050
23+
const startIdx608 = 16520
2424

2525
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
26+
27+
//const weightsUrl = '/tmp/tmp_2_count_41000.weights'
2628
const fromEpoch = 0
2729

30+
const trainOnlyMultibox = false
2831

2932
window.debug = false
3033
window.logTrainSteps = true
@@ -42,6 +45,8 @@
4245
window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
4346

4447
// all samples
48+
//const dataStartIdx = 8000
49+
const dataStartIdx = 0
4550
const numTrainSamples = Infinity
4651

4752
async function loadNetWeights(uri) {
@@ -52,28 +57,51 @@
5257
return fetch('/detection_filenames').then(res => res.json())
5358
}
5459

60+
async function fetchDetectionFilenamesMultibox() {
61+
return fetch('/detection_filenames_multibox').then(res => res.json())
62+
}
63+
5564
async function run() {
5665
const weights = await loadNetWeights(weightsUrl)
5766
window.net = new faceapi.TinyYolov2(true)
5867
window.net.load(weights)
5968
window.net.variable()
60-
window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples)
69+
70+
const fetchDetectionsFn = trainOnlyMultibox
71+
? fetchDetectionFilenamesMultibox
72+
: fetchDetectionFilenames
73+
74+
window.detectionFilenames = (await fetchDetectionsFn()).slice(dataStartIdx, dataStartIdx + numTrainSamples)
75+
6176
window.lossMap = {}
6277

6378
console.log('ready')
6479
}
6580

66-
//const trainSizes = [224, 320, 416, 608]
81+
//const trainSizes = [224, 320, 416]
6782
const trainSizes = [608]
6883

6984
function logLossChange(lossType) {
7085
const { currentLoss, prevLoss, detectionFilenames } = window
7186
log(`${lossType} : ${faceapi.round(currentLoss[lossType])} (avg: ${faceapi.round(currentLoss[lossType] / detectionFilenames.length)}) (delta: ${currentLoss[lossType] - prevLoss[lossType]})`)
7287
}
7388

89+
window.count = 0
90+
91+
function _onBatchProcessed(dataIdx, inputSize) {
92+
window.count++
93+
const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
94+
console.log('dataIdx', dataIdx)
95+
if ((window.count % saveEveryNthDataIdx) === 0) {
96+
saveWeights(window.net, `tmp_2_count_${window.count}.weights`)
97+
}
98+
}
99+
74100
function onBatchProcessed(dataIdx, inputSize) {
75-
if (((dataIdx + 1) % saveEveryNthDataIdx) === 0) {
76-
saveWeights(window.net, `tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608 + dataIdx + 1}.weights`)
101+
const idx = (dataIdx + 1) + (window.epoch * window.detectionFilenames.length)
102+
console.log('idx', idx)
103+
if ((idx % saveEveryNthDataIdx) === 0) {
104+
saveWeights(window.net, `tmp__224_${startIdx224 + (inputSize === 224 ? idx : 0)}__320_${startIdx320 + (inputSize === 320 ? idx : 0)}__416_${startIdx416 + (inputSize === 416 ? idx : 0)}__608_${startIdx608 + (inputSize === 608 ? idx : 0)}.weights`)
77105
}
78106
}
79107

@@ -82,6 +110,7 @@
82110
const batchSize = 1
83111

84112
for (let i = fromEpoch; i < trainSteps; i++) {
113+
window.epoch = i
85114
log('step', i)
86115
let ts2 = Date.now()
87116

tools/train/tinyYolov2/train.js

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,23 +73,24 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
7373
.rescale({ height: imgHeight, width: imgWidth })
7474
.rescale(scaleFactor)
7575

76-
const isTooTiny = box.width < 50 || box.height < 50
77-
if (isTooTiny) {
76+
const isTooTiny = box.width < 40 || box.height < 40
77+
if (isTooTiny && window.debug) {
7878
log(`skipping box for input size ${inputSize}: (${Math.floor(box.width)} x ${Math.floor(box.height)})`)
7979
}
8080
return !isTooTiny
8181
})
8282

8383
if (!filteredGroundTruthBoxes.length) {
84-
log(`no boxes for input size ${inputSize}, ${groundTruthBoxes[batchIdx].length} boxes were too small`)
84+
if (window.debug) {
85+
log(`no boxes for input size ${inputSize}, ${groundTruthBoxes[batchIdx].length} boxes were too small`)
86+
}
8587
batchInput.dispose()
8688
onBatchProcessed(dataIdx, inputSize)
8789
return
8890
}
8991

9092
let ts = Date.now()
9193
const loss = minimize(filteredGroundTruthBoxes, batchInput, inputSize, batch)
92-
9394
ts = Date.now() - ts
9495
if (window.logTrainSteps) {
9596
log(`trainStep time for dataIdx ${dataIdx} (${inputSize}): ${ts} ms`)
@@ -109,7 +110,15 @@ async function trainStep(batchCreators, inputSizes, rescaleEveryNthBatch, onBatc
109110
await step(batchCreators.next(rescaleEveryNthBatch))
110111
}
111112

112-
function createBatchCreators(detectionFilenames, batchSize, ) {
113+
async function fetchGroundTruthBoxesForFile(file) {
114+
const boxes = await fetch(file).then(res => res.json())
115+
return {
116+
file,
117+
boxes
118+
}
119+
}
120+
121+
function createBatchCreators(detectionFilenames, batchSize) {
113122
if (batchSize < 1) {
114123
throw new Error('invalid batch size: ' + batchSize)
115124
}
@@ -126,9 +135,8 @@ function createBatchCreators(detectionFilenames, batchSize, ) {
126135
pushToBatch(detectionFilenames)
127136

128137
const batchCreators = batches.map((filenamesForBatch, dataIdx) => async () => {
129-
const groundTruthBoxes = await Promise.all(filenamesForBatch.map(
130-
file => fetch(file).then(res => res.json())
131-
))
138+
const groundTruthBoxes = (await Promise.all(filenamesForBatch.map(fetchGroundTruthBoxesForFile)))
139+
.map(({ boxes }) => boxes)
132140

133141
const imgs = await Promise.all(filenamesForBatch.map(
134142
async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', '')))

tools/train/tinyYolov2/verify.html

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,19 @@
138138

139139
async function run() {
140140
$('#imgByNr').keydown(onKeyDown)
141-
const startIdx224 = 3220
142-
const startIdx320 = 20688
143-
const startIdx416 = 950
144-
const startIdx608 = 15220
141+
142+
const startIdx224 = 35060
143+
const startIdx320 = 41188
144+
const startIdx416 = 31050
145+
const startIdx608 = 16520
145146

146147
const weightsUrl = `/tmp/tmp__224_${startIdx224}__320_${startIdx320}__416_${startIdx416}__608_${startIdx608}.weights`
147148

148149
const weights = await loadNetWeights(weightsUrl)
149150
window.net = new faceapi.TinyYolov2(true)
150151
await window.net.load(weights)
151152

152-
window.imgs = (await fetchDetectionFilenames()).slice(0, 100).map(f => f.replace('.json', ''))
153+
window.imgs = (await fetchDetectionFilenames()).map(f => f.replace('.json', ''))
153154

154155
$('#loader').hide()
155156
onSelectionChanged($('#selectList select').val())

0 commit comments

Comments
 (0)