Skip to content

Commit 7b651ee

Browse files
init training function
1 parent aa7dd87 commit 7b651ee

File tree

3 files changed

+57
-45
lines changed

3 files changed

+57
-45
lines changed

tools/train/tinyYolov2/loss.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ function computeCoordLoss({ groundTruth, pred }, imgDims) {
3737
+ squared(getHeightCorrections(groundTruth.box) - getHeightCorrections(pred.box))
3838
}
3939

40-
function computeLoss(outBoxesByAnchor, groundTruth, inputSize, imgDims) {
40+
function computeLoss(outBoxesByAnchor, groundTruth, imgDims) {
4141

4242
const { anchors } = window.net
43+
const inputSize = Math.max(imgDims.width, imgDims.height)
4344
const numCells = inputSize / 32
4445

4546
const groundTruthByAnchor = groundTruth.map(rect => {

tools/train/tinyYolov2/train.html

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
const weightsUrl = '/tmp/initial_tiny_yolov2_glorot_normal.weights'
1919

20-
window.saveEveryNthIteration = 2
20+
window.saveEveryNthIteration = 1
2121
window.trainSteps = 100
2222
window.optimizer = tf.train.adam(0.001, 0.9, 0.999, 1e-8)
2323

@@ -41,38 +41,27 @@
4141
window.detectionFilenames = await fetchDetectionFilenames()
4242
}
4343

44-
/*
44+
const trainSizes = [608, 416, 320, 224]
4545

46-
const outTensor = await window.net.forward(netInput, 608)
47-
const detections = await window.net.locateFaces(netInput, forwardParams)
48-
const outBoxesByAnchor = window.net.postProcess(
49-
outTensor,
50-
{
51-
scoreThreshold: 0,
52-
paddings: netInput.getRelativePaddings(0)
53-
}
54-
)
46+
async function train(batchSize = 1) {
47+
for (let i = 0; i < trainSteps; i++) {
48+
console.log('step', i)
49+
let ts = Date.now()
5550

56-
const groundTruth = detections.map(det => det.forSize(1, 1).box)
51+
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
5752

58-
console.log(computeLoss(
59-
outBoxesByAnchor,
60-
groundTruth,
61-
netInput.inputSize,
62-
netInput.getReshapedInputDimensions(0)
63-
))
53+
for (let s = 0; s < trainSizes.length; s++) {
54+
let ts2 = Date.now()
6455

56+
await trainStep(batchCreators, trainSizes[s])
6557

66-
*/
58+
ts2 = Date.now() - ts2
59+
console.log('train for size %s done (%s ms)', trainSizes[s], ts2)
60+
}
6761

68-
async function train(batchSize = 1) {
69-
for (let i = 0; i < trainSteps; i++) {
70-
console.log('step', i)
71-
const batchCreators = createBatchCreators(shuffle(detectionFilenames), batchSize)
72-
let ts = Date.now()
73-
await trainStep(batchCreators)
7462
ts = Date.now() - ts
7563
console.log('step %s done (%s ms)', i, ts)
64+
7665
if (((i + 1) % saveEveryNthIteration) === 0) {
7766
saveWeights(i)
7867
}

tools/train/tinyYolov2/train.js

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,50 @@
1-
async function trainStep(batchCreators) {
1+
async function trainStep(batchCreators, inputSize) {
2+
23
await promiseSequential(batchCreators.map((batchCreator, dataIdx) => async () => {
34

4-
const { batchInput, groundTruthBoxes } = await batchCreator()
5-
/*
5+
// TODO: skip if groundTruthBoxes are too tiny
6+
const { imgs, groundTruthBoxes } = await batchCreator()
7+
8+
const batchInput = (await faceapi.toNetInput(imgs)).managed()
9+
610
let ts = Date.now()
7-
const cost = optimizer.minimize(() => {
8-
const out = window.trainNet.forwardInput(batchInput.managed())
9-
const loss = lossFunction(
10-
landmarksBatchTensor,
11-
out
11+
const loss = optimizer.minimize(() => {
12+
const outTensor = window.net.forwardInput(batchInput, inputSize)
13+
const outTensorsByBatch = tf.tidy(() => outTensor.unstack().expandDims())
14+
outTensor.dispose()
15+
16+
const losses = outTensorsByBatch.map(
17+
(out, batchIdx) => {
18+
const outBoxesByAnchor = window.net.postProcess(
19+
out,
20+
{
21+
scoreThreshold: -1,
22+
paddings: batchInput.getRelativePaddings(batchIdx)
23+
}
24+
)
25+
26+
const loss = computeLoss(
27+
outBoxesByAnchor,
28+
groundTruthBoxes[batchIdx],
29+
netInput.getReshapedInputDimensions(batchIdx)
30+
)
31+
32+
console.log(`loss for batch ${batchIdx}: ${loss}`)
33+
34+
return loss
35+
}
1236
)
13-
return loss
37+
38+
outTensorsByBatch.forEach(t => t.dispose())
39+
40+
return losses.reduce((sum, loss) => sum + loss, 0)
1441
}, true)
1542

1643
ts = Date.now() - ts
17-
console.log(`loss[${dataIdx}]: ${await cost.data()}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
18-
19-
landmarksBatchTensor.dispose()
20-
cost.dispose()
44+
console.log(`loss[${dataIdx}]: ${loss}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
2145

2246
await tf.nextFrame()
2347
}))
24-
*/
2548
}
2649

2750
function createBatchCreators(batchSize) {
@@ -42,17 +65,16 @@ function createBatchCreators(batchSize) {
4265
pushToBatch(window.detectionFilenames)
4366

4467
const batchCreators = batches.map(detectionFilenames => async () => {
45-
const imgs = detectionFilenames.map(
68+
const groundTruthBoxes = detectionFilenames.map(
4669
detectionFilenames.map(file => fetch(file).then(res => res.json()))
4770
)
48-
const groundTruthBoxes = await Promise.all(
71+
72+
const imgs = await Promise.all(
4973
detectionFilenames.map(async file => await faceapi.bufferToImage(await fetchImage(file.replace('.json', ''))))
5074
)
5175

52-
const batchInput = await faceapi.toNetInput(imgs)
53-
5476
return {
55-
batchInput,
77+
imgs,
5678
groundTruthBoxes
5779
}
5880
})

0 commit comments

Comments
 (0)