1
- async function trainStep ( batchCreators ) {
1
+ async function trainStep ( batchCreators , inputSize ) {
2
+
2
3
await promiseSequential ( batchCreators . map ( ( batchCreator , dataIdx ) => async ( ) => {
3
4
4
- const { batchInput, groundTruthBoxes } = await batchCreator ( )
5
- /*
5
+ // TODO: skip if groundTruthBoxes are too tiny
6
+ const { imgs, groundTruthBoxes } = await batchCreator ( )
7
+
8
+ const batchInput = ( await faceapi . toNetInput ( imgs ) ) . managed ( )
9
+
6
10
let ts = Date . now ( )
7
- const cost = optimizer.minimize(() => {
8
- const out = window.trainNet.forwardInput(batchInput.managed())
9
- const loss = lossFunction(
10
- landmarksBatchTensor,
11
- out
11
+ const loss = optimizer . minimize ( ( ) => {
12
+ const outTensor = window . net . forwardInput ( batchInput , inputSize )
13
+ const outTensorsByBatch = tf . tidy ( ( ) => outTensor . unstack ( ) . expandDims ( ) )
14
+ outTensor . dispose ( )
15
+
16
+ const losses = outTensorsByBatch . map (
17
+ ( out , batchIdx ) => {
18
+ const outBoxesByAnchor = window . net . postProcess (
19
+ out ,
20
+ {
21
+ scoreThreshold : - 1 ,
22
+ paddings : batchInput . getRelativePaddings ( batchIdx )
23
+ }
24
+ )
25
+
26
+ const loss = computeLoss (
27
+ outBoxesByAnchor ,
28
+ groundTruthBoxes [ batchIdx ] ,
29
+ netInput . getReshapedInputDimensions ( batchIdx )
30
+ )
31
+
32
+ console . log ( `loss for batch ${ batchIdx } : ${ loss } ` )
33
+
34
+ return loss
35
+ }
12
36
)
13
- return loss
37
+
38
+ outTensorsByBatch . forEach ( t => t . dispose ( ) )
39
+
40
+ return losses . reduce ( ( sum , loss ) => sum + loss , 0 )
14
41
} , true )
15
42
16
43
ts = Date . now ( ) - ts
17
- console.log(`loss[${dataIdx}]: ${await cost.data()}, ${ts} ms (${ts / batchInput.batchSize} ms / batch element)`)
18
-
19
- landmarksBatchTensor.dispose()
20
- cost.dispose()
44
+ console . log ( `loss[${ dataIdx } ]: ${ loss } , ${ ts } ms (${ ts / batchInput . batchSize } ms / batch element)` )
21
45
22
46
await tf . nextFrame ( )
23
47
} ) )
24
- */
25
48
}
26
49
27
50
function createBatchCreators ( batchSize ) {
@@ -42,17 +65,16 @@ function createBatchCreators(batchSize) {
42
65
pushToBatch ( window . detectionFilenames )
43
66
44
67
const batchCreators = batches . map ( detectionFilenames => async ( ) => {
45
- const imgs = detectionFilenames . map (
68
+ const groundTruthBoxes = detectionFilenames . map (
46
69
detectionFilenames . map ( file => fetch ( file ) . then ( res => res . json ( ) ) )
47
70
)
48
- const groundTruthBoxes = await Promise . all (
71
+
72
+ const imgs = await Promise . all (
49
73
detectionFilenames . map ( async file => await faceapi . bufferToImage ( await fetchImage ( file . replace ( '.json' , '' ) ) ) )
50
74
)
51
75
52
- const batchInput = await faceapi . toNetInput ( imgs )
53
-
54
76
return {
55
- batchInput ,
77
+ imgs ,
56
78
groundTruthBoxes
57
79
}
58
80
} )
0 commit comments