Skip to content

Commit 3afe611

Browse files
fixed loss function
1 parent 99cbd7e commit 3afe611

File tree

3 files changed

+38
-54
lines changed

3 files changed

+38
-54
lines changed

tools/train/tinyYolov2/loss.js

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@ function assignBoxesToAnchors(groundTruthBoxes, reshapedImgDims) {
1818
const numCells = getNumCells(inputSize)
1919

2020
return groundTruthBoxes.map(box => {
21-
const { left: x, top: y, width, height } = box.rescale(reshapedImgDims)
21+
const { left, top, width, height } = box.rescale(reshapedImgDims)
2222

23-
const row = Math.round((y / inputSize) * numCells)
24-
const col = Math.round((x / inputSize) * numCells)
23+
const ctX = left + (width / 2)
24+
const ctY = top + (height / 2)
25+
26+
const col = Math.floor((ctX / inputSize) * numCells)
27+
const row = Math.floor((ctY / inputSize) * numCells)
2528

2629
const anchorsByIou = getAnchors().map((anchor, idx) => ({
2730
idx,
@@ -92,11 +95,13 @@ function computeBoxAdjustments(groundTruthBoxes, reshapedImgDims) {
9295
const centerX = (left + right) / 2
9396
const centerY = (top + bottom) / 2
9497

95-
const dCenterX = centerX - (col * CELL_SIZE + (CELL_SIZE / 2))
96-
const dCenterY = centerY - (row * CELL_SIZE + (CELL_SIZE / 2))
98+
//const dCenterX = centerX - (col * CELL_SIZE + (CELL_SIZE / 2))
99+
//const dCenterY = centerY - (row * CELL_SIZE + (CELL_SIZE / 2))
100+
const dCenterX = centerX - (col * CELL_SIZE)
101+
const dCenterY = centerY - (row * CELL_SIZE)
97102

98-
const dx = inverseSigmoid(dCenterX / inputSize)
99-
const dy = inverseSigmoid(dCenterY / inputSize)
103+
const dx = inverseSigmoid(dCenterX / CELL_SIZE)
104+
const dy = inverseSigmoid(dCenterY / CELL_SIZE)
100105
const dw = Math.log((width / CELL_SIZE) / getAnchors()[anchor].x)
101106
const dh = Math.log((height / CELL_SIZE) / getAnchors()[anchor].y)
102107

@@ -134,13 +139,14 @@ function computeIous(predBoxes, groundTruthBoxes, reshapedImgDims) {
134139

135140
const iou = faceapi.iou(
136141
box.rescale(reshapedImgDims),
137-
predBox.box
142+
predBox.box.rescale(reshapedImgDims)
138143
)
139144

140145
if (window.debug) {
141-
console.log('ground thruth box:', box.rescale(reshapedImgDims))
142-
console.log('predicted box:', predBox.box)
143-
console.log(iou)
146+
console.log('ground thruth box:', box.rescale(reshapedImgDims).toRect())
147+
console.log('predicted box:', predBox.box.rescale(reshapedImgDims).toRect())
148+
console.log('predicted score:', predBox.score)
149+
console.log('iou:', iou)
144150
}
145151

146152
const anchorOffset = anchor * 5
@@ -164,31 +170,6 @@ function computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, padding
164170
{ paddings }
165171
)
166172

167-
if (window.debug) {
168-
console.log(predBoxes)
169-
console.log(predBoxes.filter(b => b.score > 0.1))
170-
}
171-
172-
// debug
173-
174-
const numCells = getNumCells(Math.max(reshapedImgDims.width, reshapedImgDims.height))
175-
if (predBoxes.length !== (numCells * numCells * getAnchors().length)) {
176-
console.log(predBoxes.length)
177-
throw new Error('predBoxes.length !== (numCells * numCells * 25)')
178-
}
179-
180-
const isInvalid = num => !num && num !== 0
181-
182-
183-
predBoxes.forEach(({ row, col, anchor }) => {
184-
if ([row, col, anchor].some(isInvalid)) {
185-
console.log(row, col, anchor)
186-
throw new Error('row, col, anchor invalid')
187-
}
188-
})
189-
190-
// debug
191-
192173
const ious = computeIous(
193174
predBoxes,
194175
groundTruthBoxes,
@@ -208,7 +189,6 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa
208189
reshapedImgDims
209190
)
210191

211-
// debug
212192
if (window.debug) {
213193
const indToPos = []
214194
const numCells = outTensor.shape[1]
@@ -220,40 +200,39 @@ function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims, mask, pa
220200
}
221201
}
222202

223-
const m = Array.from(mask.dataSync())
224-
const ind = m.map((val, ind) => ({ val, ind })).filter(v => v.val !== 0).map(v => v.ind)
203+
const indices = Array.from(mask.dataSync()).map((val, ind) => ({ val, ind })).filter(v => v.val !== 0).map(v => v.ind)
225204
const gt = Array.from(boxAdjustments.dataSync())
226205
const out = Array.from(outTensor.dataSync())
227206

228-
const comp = ind.map(i => (
207+
const comp = indices.map(i => (
229208
{
230209
pos: indToPos[i],
231210
gt: gt[i],
232211
out: out[i]
233212
}
234213
))
235-
console.log(comp)
236214
console.log(comp.map(c => `gt: ${c.gt}, out: ${c.out}`))
237215

238-
const printBbox = (which) => {
239-
const { col, row, anchor } = comp[0].pos
240-
console.log(col, row, anchor)
216+
const getBbox = (which) => {
217+
const { row, col, anchor } = comp[0].pos
218+
241219
const ctX = ((col + faceapi.sigmoid(comp[0][which])) / numCells) * paddings.x
242220
const ctY = ((row + faceapi.sigmoid(comp[1][which])) / numCells) * paddings.y
243221
const width = ((Math.exp(comp[2][which]) * getAnchors()[anchor].x) / numCells) * paddings.x
244222
const height = ((Math.exp(comp[3][which]) * getAnchors()[anchor].y) / numCells) * paddings.y
245223

246224
const x = (ctX - (width / 2))
247225
const y = (ctY - (height / 2))
248-
console.log(which, x * reshapedImgDims.width, y * reshapedImgDims.height, width * reshapedImgDims.width, height * reshapedImgDims.height)
249-
}
250-
251226

252-
printBbox('out')
253-
printBbox('gt')
227+
return new faceapi.BoundingBox(x, y, x + width, y + height)
228+
}
254229

230+
const outRect = getBbox('out').rescale(reshapedImgDims).toRect()
231+
const gtRect = getBbox('gt').rescale(reshapedImgDims).toRect()
232+
console.log('out', outRect)
233+
console.log('gtRect', gtRect)
255234
}
256-
// debug
235+
257236

258237
const lossTensor = tf.sub(boxAdjustments, outTensor)
259238

tools/train/tinyYolov2/overfit.html

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@
2020
const weightsUrl = '/tmp/test.weights'
2121
const fromEpoch = 800
2222

23+
24+
window.debug = true
25+
window.logTrainSteps = true
26+
27+
2328
// hyper parameters
2429
window.objectScale = 5
25-
window.noObjectScale = 0.5
30+
window.noObjectScale = 1
2631
window.coordScale = 1
2732

2833
window.saveEveryNthIteration = 50
@@ -45,7 +50,7 @@
4550
window.net = new faceapi.TinyYolov2(true)
4651
window.net.load(weights)
4752
window.net.variable()
48-
window.detectionFilenames = (await fetchDetectionFilenames()).slice(0, numTrainSamples)
53+
window.detectionFilenames = (await fetchDetectionFilenames()).slice(1, numTrainSamples + 1)
4954
window.lossMap = {}
5055

5156
console.log('ready')
@@ -90,7 +95,7 @@
9095

9196

9297
if (((i + 1) % saveEveryNthIteration) === 0) {
93-
saveWeights(window.net, 'adam_511_n1_' + i + '.weights')
98+
saveWeights(window.net, 'adam_511_' + i + '.weights')
9499
}
95100
}
96101
}

tools/train/tinyYolov2/verify.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@
138138

139139
async function run() {
140140
$('#imgByNr').keydown(onKeyDown)
141-
const weights = await loadNetWeights('/tmp/test.weights')
141+
const weights = await loadNetWeights('/tmp/test2.weights')
142142
window.net = new faceapi.TinyYolov2(true)
143143
await window.net.load(weights)
144144

0 commit comments

Comments
 (0)