@@ -3,104 +3,195 @@ const objectScale = 1
3
3
const noObjectScale = 0.5
4
4
const coordScale = 5
5
5
6
- const squared = e => Math . pow ( e , 2 )
6
+ const CELL_SIZE = 32
7
7
8
- const isSameAnchor = ( p1 , p2 ) =>
9
- p1 . row === p2 . row
10
- && p1 . col === p2 . col
11
- && p1 . anchor === p2 . anchor
8
+ const getNumCells = inputSize => inputSize / CELL_SIZE
12
9
13
- const sum = vals => vals . reduce ( ( sum , val ) => sum + val , 0 )
14
-
15
- function computeNoObjectLoss ( negative ) {
16
- return squared ( 0 - negative . score )
10
+ function getAnchors ( ) {
11
+ return window . net . anchors
17
12
}
18
13
19
- function computeObjectLoss ( { groundTruth, pred } ) {
20
- return squared (
21
- faceapi . iou (
22
- groundTruth . box ,
23
- pred . box
24
- )
25
- - pred . score
26
- )
27
- }
14
+ function assignBoxesToAnchors ( groundTruthBoxes , reshapedImgDims ) {
28
15
29
- function computeCoordLoss ( { groundTruth, pred } , imgDims ) {
30
- const anchor = window . net . anchors [ groundTruth . anchor ]
31
- const getWidthCorrections = box => Math . log ( ( box . width / imgDims . width ) / anchor . x )
32
- const getHeightCorrections = box => Math . log ( ( box . height / imgDims . height ) / anchor . y )
16
+ const inputSize = Math . max ( reshapedImgDims . width , reshapedImgDims . height )
17
+ const numCells = getNumCells ( inputSize )
33
18
34
- return squared ( ( groundTruth . box . left - pred . box . left ) / imgDims . width )
35
- + squared ( ( groundTruth . box . top - pred . box . top ) / imgDims . height )
36
- + squared ( getWidthCorrections ( groundTruth . box ) - getWidthCorrections ( pred . box ) )
37
- + squared ( getHeightCorrections ( groundTruth . box ) - getHeightCorrections ( pred . box ) )
38
- }
39
-
40
- function computeLoss ( outBoxesByAnchor , groundTruth , imgDims ) {
41
-
42
- const { anchors } = window . net
43
- const inputSize = Math . max ( imgDims . width , imgDims . height )
44
- const numCells = inputSize / 32
45
-
46
- const groundTruthByAnchor = groundTruth . map ( rect => {
47
- const x = rect . x * imgDims . width
48
- const y = rect . y * imgDims . height
49
- const width = rect . width * imgDims . width
50
- const height = rect . height * imgDims . height
19
+ return groundTruthBoxes . map ( box => {
20
+ const { left : x , top : y , width, height } = box . rescale ( reshapedImgDims )
51
21
52
22
const row = Math . round ( ( y / inputSize ) * numCells )
53
23
const col = Math . round ( ( x / inputSize ) * numCells )
54
- const anchorsByIou = anchors . map ( ( a , idx ) => ( {
24
+
25
+ const anchorsByIou = getAnchors ( ) . map ( ( anchor , idx ) => ( {
55
26
idx,
56
27
iou : faceapi . iou (
57
- new faceapi . BoundingBox ( 0 , 0 , a . x * 32 , a . y * 32 ) ,
28
+ new faceapi . BoundingBox ( 0 , 0 , anchor . x * CELL_SIZE , anchor . y * CELL_SIZE ) ,
58
29
new faceapi . BoundingBox ( 0 , 0 , width , height )
59
30
)
60
31
} ) ) . sort ( ( a1 , a2 ) => a2 . iou - a1 . iou )
61
32
62
- console . log ( 'anchorsByIou' , anchorsByIou )
63
-
64
33
const anchor = anchorsByIou [ 0 ] . idx
65
34
66
- return {
67
- box : new faceapi . BoundingBox ( x , y , x + width , y + height ) ,
68
- row,
69
- col,
70
- anchor
35
+ return { row, col, anchor, box }
36
+ } )
37
+ }
38
+
39
+ function getGroundTruthMask ( groundTruthBoxes , inputSize ) {
40
+
41
+ const numCells = getNumCells ( inputSize )
42
+
43
+ const mask = tf . zeros ( [ numCells , numCells , 25 ] )
44
+ const buf = mask . buffer ( )
45
+
46
+ groundTruthBoxes . forEach ( ( { row, col, anchor } ) => {
47
+ const anchorOffset = anchor * 5
48
+ for ( let i = 0 ; i < 5 ; i ++ ) {
49
+ buf . set ( 1 , row , col , anchorOffset + i )
71
50
}
72
51
} )
73
52
74
- console . log ( 'outBoxesByAnchor' , outBoxesByAnchor . filter ( o => o . score > 0.5 ) . map ( o => o ) )
75
- console . log ( 'outBoxesByAnchor' , outBoxesByAnchor . filter ( o => o . score > 0.5 ) . map ( o => o . box . rescale ( imgDims ) ) )
76
- console . log ( 'groundTruthByAnchor' , groundTruthByAnchor )
77
-
78
- const negatives = outBoxesByAnchor . filter ( pred => ! groundTruthByAnchor . find ( gt => isSameAnchor ( gt , pred ) ) )
79
- const positives = outBoxesByAnchor
80
- . map ( pred => ( {
81
- groundTruth : groundTruthByAnchor . find ( gt => isSameAnchor ( gt , pred ) ) ,
82
- pred : {
83
- ...pred ,
84
- box : pred . box . rescale ( imgDims )
85
- }
86
- } ) )
87
- . filter ( pos => ! ! pos . groundTruth )
88
-
89
-
90
- console . log ( 'negatives' , negatives )
91
- console . log ( 'positives' , positives )
92
-
93
- const noObjectLoss = sum ( negatives . map ( computeNoObjectLoss ) )
94
- const objectLoss = sum ( positives . map ( computeObjectLoss ) )
95
- const coordLoss = sum ( positives . map ( positive => computeCoordLoss ( positive , imgDims ) ) )
96
-
97
- console . log ( 'noObjectLoss' , noObjectLoss )
98
- console . log ( 'objectLoss' , objectLoss )
99
- console . log ( 'coordLoss' , coordLoss )
100
-
101
- return noObjectScale * noObjectLoss
102
- + objectScale * objectLoss
103
- + coordScale * coordLoss
104
- // we don't compute a class loss, since we only have 1 class
105
- // + class_scale * sum(class_loss)
53
+ return mask
54
+ }
55
+
56
+ function computeBoxAdjustments ( groundTruthBoxes , reshapedImgDims ) {
57
+
58
+ const inputSize = Math . max ( reshapedImgDims . width , reshapedImgDims . height )
59
+ const numCells = getNumCells ( inputSize )
60
+
61
+ const adjustments = tf . zeros ( [ numCells , numCells , 25 ] )
62
+ const buf = adjustments . buffer ( )
63
+
64
+ groundTruthBoxes . forEach ( ( { row, col, anchor, box } ) => {
65
+ const { left, top, right, bottom, width, height } = box . rescale ( reshapedImgDims )
66
+
67
+ const centerX = ( left + right ) / 2
68
+ const centerY = ( top + bottom ) / 2
69
+ const dx = ( centerX - ( col * CELL_SIZE + ( CELL_SIZE / 2 ) ) ) / inputSize
70
+ const dy = ( centerY - ( row * CELL_SIZE + ( CELL_SIZE / 2 ) ) ) / inputSize
71
+ const dw = Math . log ( width / getAnchors ( ) [ anchor ] . x )
72
+ const dh = Math . log ( height / getAnchors ( ) [ anchor ] . y )
73
+
74
+ const anchorOffset = anchor * 5
75
+ buf . set ( dx , row , col , anchorOffset + 0 )
76
+ buf . set ( dy , row , col , anchorOffset + 1 )
77
+ buf . set ( dw , row , col , anchorOffset + 2 )
78
+ buf . set ( dh , row , col , anchorOffset + 3 )
79
+ } )
80
+
81
+ return adjustments
82
+ }
83
+
84
+ function computeIous ( predBoxes , groundTruthBoxes , reshapedImgDims ) {
85
+
86
+ const numCells = getNumCells ( Math . max ( reshapedImgDims . width , reshapedImgDims . height ) )
87
+
88
+ const isSameAnchor = p1 => p2 =>
89
+ p1 . row === p2 . row
90
+ && p1 . col === p2 . col
91
+ && p1 . anchor === p2 . anchor
92
+
93
+ const ious = tf . zeros ( [ numCells , numCells , 25 ] )
94
+ const buf = ious . buffer ( )
95
+
96
+ groundTruthBoxes . forEach ( ( { row, col, anchor, box } ) => {
97
+ const predBox = predBoxes . find ( isSameAnchor ( { row, col, anchor } ) )
98
+
99
+ if ( ! predBox ) {
100
+ console . log ( groundTruthBoxes )
101
+ console . log ( predBoxes )
102
+ throw new Error ( `no output box found for: row ${ row } , col ${ col } , anchor ${ anchor } ` )
103
+ }
104
+
105
+ const iou = faceapi . iou (
106
+ box . rescale ( reshapedImgDims ) ,
107
+ predBox . box . rescale ( reshapedImgDims )
108
+ )
109
+
110
+ const anchorOffset = anchor * 5
111
+ buf . set ( iou , row , col , anchorOffset + 4 )
112
+ } )
113
+
114
+ return ious
115
+ }
116
+
117
+ function computeNoObjectLoss ( outTensor ) {
118
+ return tf . tidy ( ( ) => tf . square ( tf . sigmoid ( outTensor ) ) )
119
+ }
120
+
121
+ function computeObjectLoss ( outTensor , groundTruthBoxes , reshapedImgDims , paddings ) {
122
+ return tf . tidy ( ( ) => {
123
+ const predBoxes = window . net . postProcess (
124
+ outTensor ,
125
+ { paddings }
126
+ )
127
+ const ious = computeIous (
128
+ predBoxes ,
129
+ groundTruthBoxes ,
130
+ reshapedImgDims
131
+ )
132
+
133
+ return tf . square ( tf . sub ( ious , tf . sigmoid ( outTensor ) ) )
134
+ } )
135
+ }
136
+
137
+ function computeCoordLoss ( groundTruthBoxes , outTensor , reshapedImgDims ) {
138
+ return tf . tidy ( ( ) => {
139
+ const boxAdjustments = computeBoxAdjustments (
140
+ groundTruthBoxes ,
141
+ reshapedImgDims
142
+ )
143
+
144
+ return tf . square ( tf . sub ( boxAdjustments , outTensor ) )
145
+ } )
146
+ }
147
+
148
+ function computeLoss ( outTensor , groundTruth , reshapedImgDims , paddings ) {
149
+
150
+ const inputSize = Math . max ( reshapedImgDims . width , reshapedImgDims . height )
151
+
152
+ if ( ! inputSize ) {
153
+ throw new Error ( `invalid inputSize: ${ inputSize } ` )
154
+ }
155
+
156
+ let groundTruthBoxes = assignBoxesToAnchors (
157
+ groundTruth
158
+ . map ( ( { x, y, width, height } ) => new faceapi . Rect ( x , y , width , height ) )
159
+ . map ( rect => rect . toBoundingBox ( ) ) ,
160
+ reshapedImgDims
161
+ )
162
+
163
+ const mask = getGroundTruthMask (
164
+ groundTruthBoxes ,
165
+ inputSize
166
+ )
167
+ const inverseMask = tf . tidy ( ( ) => tf . sub ( tf . scalar ( 1 ) , mask ) )
168
+
169
+ const noObjectLoss = tf . tidy ( ( ) =>
170
+ tf . mul (
171
+ tf . scalar ( noObjectScale ) ,
172
+ tf . sum ( tf . mul ( inverseMask , computeNoObjectLoss ( outTensor ) ) )
173
+ )
174
+ )
175
+ const objectLoss = tf . tidy ( ( ) =>
176
+ tf . mul (
177
+ tf . scalar ( objectScale ) ,
178
+ tf . sum ( tf . mul ( mask , computeObjectLoss ( outTensor , groundTruthBoxes , reshapedImgDims , paddings ) ) )
179
+ )
180
+ )
181
+
182
+ const coordLoss = tf . tidy ( ( ) =>
183
+ tf . mul (
184
+ tf . scalar ( coordScale ) ,
185
+ tf . sum ( tf . mul ( mask , computeCoordLoss ( groundTruthBoxes , outTensor , reshapedImgDims ) ) )
186
+ )
187
+ )
188
+
189
+ const totalLoss = tf . tidy ( ( ) => noObjectLoss . add ( objectLoss ) . add ( coordLoss ) )
190
+
191
+ return {
192
+ noObjectLoss,
193
+ objectLoss,
194
+ coordLoss,
195
+ totalLoss
196
+ }
106
197
}
0 commit comments