Skip to content

Commit 459067c

Browse files
rewrite loss function with tf ops, such that it's differentiable
1 parent 7b651ee commit 459067c

File tree

9 files changed

+231
-133
lines changed

9 files changed

+231
-133
lines changed

src/Rect.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { BoundingBox } from './BoundingBox';
12
export interface IRect {
23
x: number
34
y: number
@@ -45,4 +46,8 @@ export class Rect implements IRect {
4546
Math.floor(this.height)
4647
)
4748
}
49+
50+
public toBoundingBox(): BoundingBox {
51+
return new BoundingBox(this.x, this.y, this.x + this.width, this.y + this.height)
52+
}
4853
}

src/tinyYolov2/TinyYolov2.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
9595
const out = await this.forwardInput(netInput, inputSize)
9696
const out0 = tf.tidy(() => tf.unstack(out)[0].expandDims()) as tf.Tensor4D
9797

98-
console.log(out0.shape)
99-
10098
const inputDimensions = {
10199
width: netInput.getInputWidth(0),
102100
height: netInput.getInputHeight(0)
@@ -147,7 +145,7 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
147145
for (let col = 0; col < numCells; col ++) {
148146
for (let anchor = 0; anchor < NUM_BOXES; anchor ++) {
149147
const score = sigmoid(scoresTensor.get(row, col, anchor, 0))
150-
if (score > scoreThreshold) {
148+
if (!scoreThreshold || score > scoreThreshold) {
151149
const ctX = ((col + sigmoid(boxesTensor.get(row, col, anchor, 0))) / numCells) * paddings.x
152150
const ctY = ((row + sigmoid(boxesTensor.get(row, col, anchor, 1))) / numCells) * paddings.y
153151
const width = ((Math.exp(boxesTensor.get(row, col, anchor, 2)) * this.anchors[anchor].x) / numCells) * paddings.x

src/tinyYolov2/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ export type TinyYolov2ForwardParams = {
4646
}
4747

4848
export type PostProcessingParams = {
49-
scoreThreshold: number
49+
scoreThreshold?: number
5050
paddings: Point
5151
}

tools/train/faceLandmarks/train.html

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,11 @@
7070
ts = Date.now() - ts
7171
console.log('step %s done (%s ms)', i, ts)
7272
if (((i + 1) % saveEveryNthIteration) === 0) {
  // Name the checkpoint after the current training step `i`.
  // The call previously built the filename from `idx`, which was only the
  // parameter name of the removed local saveWeights(idx) helper; the visible
  // loop counter is `i` (it is what the old call passed in).
  // NOTE(review): confirm there is no outer `idx` binding above this hunk.
  saveWeights(window.trainNet, 'landmark_trained_weights_' + i + '.weights')
}
7575
}
7676
}
7777

78-
function saveWeights(idx = 0) {
79-
const binaryWeights = new Float32Array(
80-
window.trainNet.getParamList()
81-
.map(({ tensor }) => Array.from(tensor.dataSync()))
82-
.reduce((flat, arr) => flat.concat(arr))
83-
)
84-
saveAs(new Blob([binaryWeights]), 'landmark_trained_weights_' + idx + '.weights')
85-
}
86-
8778
</script>
8879

8980
</body>

tools/train/serveTinyYolov2.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ app.use(express.static(path.join(__dirname, '../../dist')))
1717
const trainDataPath = path.resolve(process.env.TRAIN_DATA_PATH)
1818
const imagesPath = path.join(trainDataPath, './final_images')
1919
const detectionsPath = path.join(trainDataPath, './final_detections')
20+
app.use(express.static(imagesPath))
21+
app.use(express.static(detectionsPath))
22+
2023
const detectionFilenames = fs.readdirSync(detectionsPath)
2124

2225
app.use(express.static(trainDataPath))

tools/train/shared/trainUtils.js

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,13 @@ function shuffle(a) {
1818
a[j] = x;
1919
}
2020
return a;
21+
}
22+
23+
function saveWeights(net, filename = 'train_tmp') {
24+
const binaryWeights = new Float32Array(
25+
net.getParamList()
26+
.map(({ tensor }) => Array.from(tensor.dataSync()))
27+
.reduce((flat, arr) => flat.concat(arr))
28+
)
29+
saveAs(new Blob([binaryWeights]), filename)
2130
}

tools/train/tinyYolov2/loss.js

Lines changed: 171 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,104 +3,195 @@ const objectScale = 1
33
const noObjectScale = 0.5
44
const coordScale = 5
55

6-
const squared = e => Math.pow(e, 2)
6+
const CELL_SIZE = 32
77

8-
const isSameAnchor = (p1, p2) =>
9-
p1.row === p2.row
10-
&& p1.col === p2.col
11-
&& p1.anchor === p2.anchor
8+
const getNumCells = inputSize => inputSize / CELL_SIZE
129

13-
const sum = vals => vals.reduce((sum, val) => sum + val, 0)
14-
15-
function computeNoObjectLoss(negative) {
16-
return squared(0 - negative.score)
10+
function getAnchors() {
11+
return window.net.anchors
1712
}
1813

19-
function computeObjectLoss({ groundTruth, pred }) {
20-
return squared(
21-
faceapi.iou(
22-
groundTruth.box,
23-
pred.box
24-
)
25-
- pred.score
26-
)
27-
}
14+
/**
 * Assigns every ground truth box to a grid cell and to the anchor whose shape
 * overlaps the box best.
 *
 * @param groundTruthBoxes Boxes exposing `rescale(dims)`; the rescaled box is
 *   read for `left`, `top`, `width` and `height`.
 * @param reshapedImgDims `{ width, height }` of the resized input image; the
 *   larger side is used as the network input size.
 * @returns One `{ row, col, anchor, box }` entry per input box, where `box` is
 *   the original (not rescaled) box and `anchor` is an index into getAnchors().
 */
function assignBoxesToAnchors(groundTruthBoxes, reshapedImgDims) {

  const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
  const numCells = getNumCells(inputSize)
  const anchors = getAnchors()

  // Maps a pixel coordinate to the index of the grid cell containing it,
  // clamped to the valid range so a coordinate on the far border cannot
  // produce an out-of-range index.
  const toCellIdx = coord =>
    Math.min(numCells - 1, Math.max(0, Math.floor((coord / inputSize) * numCells)))

  return groundTruthBoxes.map(box => {
    const { left, top, width, height } = box.rescale(reshapedImgDims)

    // The net decodes a box center as (col + sigmoid(t)) / numCells, i.e. the
    // responsible cell is the one containing the box CENTER. Rounding the
    // top-left corner (as before) could select a neighbouring cell or the
    // index numCells itself.
    const row = toCellIdx(top + (height / 2))
    const col = toCellIdx(left + (width / 2))

    // Compare shapes only: both boxes are anchored at the origin. Anchor sizes
    // are scaled by CELL_SIZE to bring them to the pixel scale of the box.
    const anchorsByIou = anchors.map((anchor, idx) => ({
      idx,
      iou: faceapi.iou(
        new faceapi.BoundingBox(0, 0, anchor.x * CELL_SIZE, anchor.y * CELL_SIZE),
        new faceapi.BoundingBox(0, 0, width, height)
      )
    })).sort((a1, a2) => a2.iou - a1.iou)

    // best matching anchor comes first after the descending sort
    const anchor = anchorsByIou[0].idx

    return { row, col, anchor, box }
  })
}
38+
39+
/**
 * Builds a [numCells, numCells, 25] tensor that is 1 in all 5 channels of
 * every (row, col, anchor) slot a ground truth box was assigned to and 0
 * everywhere else.
 *
 * @param groundTruthBoxes Entries providing `row`, `col` and `anchor`.
 * @param inputSize Network input size; converted to a cell count via getNumCells.
 * @returns The mask tensor.
 */
function getGroundTruthMask(groundTruthBoxes, inputSize) {

  const numCells = getNumCells(inputSize)

  // Fill a standalone, zero-initialized buffer and convert it to a tensor at
  // the end. Writing into the result of `tensor.buffer()` and returning the
  // tensor only works if the buffer happens to alias the tensor's storage.
  // presumably 25 = 5 anchors x 5 values per anchor -- verify against the net.
  const buf = tf.buffer([numCells, numCells, 25])

  groundTruthBoxes.forEach(({ row, col, anchor }) => {
    const anchorOffset = anchor * 5
    for (let i = 0; i < 5; i++) {
      buf.set(1, row, col, anchorOffset + i)
    }
  })

  return buf.toTensor()
}
55+
56+
/**
 * Builds a [numCells, numCells, 25] tensor holding the target box adjustments
 * (dx, dy, dw, dh) in channels 0..3 of every assigned (row, col, anchor) slot;
 * all other entries, including channel 4 of each slot, stay 0.
 *
 * @param groundTruthBoxes Entries providing `row`, `col`, `anchor` and `box`
 *   (`box.rescale(dims)` is read for its edges and size).
 * @param reshapedImgDims `{ width, height }` of the resized input image.
 * @returns The adjustments tensor.
 */
function computeBoxAdjustments(groundTruthBoxes, reshapedImgDims) {

  const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
  const numCells = getNumCells(inputSize)
  const anchors = getAnchors()

  // Fill a standalone buffer and convert it afterwards instead of mutating
  // the buffer of an existing tensor, which relies on storage aliasing.
  const buf = tf.buffer([numCells, numCells, 25])

  groundTruthBoxes.forEach(({ row, col, anchor, box }) => {
    const { left, top, right, bottom, width, height } = box.rescale(reshapedImgDims)

    const centerX = (left + right) / 2
    const centerY = (top + bottom) / 2

    // offset of the box center from the center of its cell, relative to the input size
    // NOTE(review): the net decodes centers as (col + sigmoid(t)) / numCells;
    // confirm this parameterization is the intended regression target.
    const dx = (centerX - (col * CELL_SIZE + (CELL_SIZE / 2))) / inputSize
    const dy = (centerY - (row * CELL_SIZE + (CELL_SIZE / 2))) / inputSize

    // Anchor sizes are expressed in cells (they are multiplied by CELL_SIZE
    // wherever they are compared against pixel boxes), so the pixel size has
    // to be converted to cells before taking the log ratio.
    const dw = Math.log((width / CELL_SIZE) / anchors[anchor].x)
    const dh = Math.log((height / CELL_SIZE) / anchors[anchor].y)

    const anchorOffset = anchor * 5
    buf.set(dx, row, col, anchorOffset + 0)
    buf.set(dy, row, col, anchorOffset + 1)
    buf.set(dw, row, col, anchorOffset + 2)
    buf.set(dh, row, col, anchorOffset + 3)
  })

  return buf.toTensor()
}
83+
84+
/**
 * Builds a [numCells, numCells, 25] tensor holding, in channel 4 of every
 * assigned (row, col, anchor) slot, the IoU between the ground truth box and
 * the predicted box of the same slot; all other entries are 0.
 *
 * @param predBoxes Predictions providing `row`, `col`, `anchor` and `box`.
 * @param groundTruthBoxes Entries providing `row`, `col`, `anchor` and `box`.
 * @param reshapedImgDims `{ width, height }` both boxes are rescaled to.
 * @returns The IoU tensor.
 * @throws Error if no prediction exists for a ground truth slot.
 */
function computeIous(predBoxes, groundTruthBoxes, reshapedImgDims) {

  const numCells = getNumCells(Math.max(reshapedImgDims.width, reshapedImgDims.height))

  // curried so it can be handed straight to Array.prototype.find
  const isSameAnchor = p1 => p2 =>
    p1.row === p2.row
      && p1.col === p2.col
      && p1.anchor === p2.anchor

  // Fill a standalone buffer and convert it afterwards instead of mutating
  // the buffer of an existing tensor, which relies on storage aliasing.
  const buf = tf.buffer([numCells, numCells, 25])

  groundTruthBoxes.forEach(({ row, col, anchor, box }) => {
    const predBox = predBoxes.find(isSameAnchor({ row, col, anchor }))

    if (!predBox) {
      // dump both lists to help diagnosing the missing slot before failing
      console.log(groundTruthBoxes)
      console.log(predBoxes)
      throw new Error(`no output box found for: row ${row}, col ${col}, anchor ${anchor}`)
    }

    const iou = faceapi.iou(
      box.rescale(reshapedImgDims),
      predBox.box.rescale(reshapedImgDims)
    )

    // channel 4 of an anchor slot is where the IoU target is stored
    const anchorOffset = anchor * 5
    buf.set(iou, row, col, anchorOffset + 4)
  })

  return buf.toTensor()
}
116+
117+
function computeNoObjectLoss(outTensor) {
118+
return tf.tidy(() => tf.square(tf.sigmoid(outTensor)))
119+
}
120+
121+
function computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, paddings) {
122+
return tf.tidy(() => {
123+
const predBoxes = window.net.postProcess(
124+
outTensor,
125+
{ paddings }
126+
)
127+
const ious = computeIous(
128+
predBoxes,
129+
groundTruthBoxes,
130+
reshapedImgDims
131+
)
132+
133+
return tf.square(tf.sub(ious, tf.sigmoid(outTensor)))
134+
})
135+
}
136+
137+
function computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims) {
138+
return tf.tidy(() => {
139+
const boxAdjustments = computeBoxAdjustments(
140+
groundTruthBoxes,
141+
reshapedImgDims
142+
)
143+
144+
return tf.square(tf.sub(boxAdjustments, outTensor))
145+
})
146+
}
147+
148+
function computeLoss(outTensor, groundTruth, reshapedImgDims, paddings) {
149+
150+
const inputSize = Math.max(reshapedImgDims.width, reshapedImgDims.height)
151+
152+
if (!inputSize) {
153+
throw new Error(`invalid inputSize: ${inputSize}`)
154+
}
155+
156+
let groundTruthBoxes = assignBoxesToAnchors(
157+
groundTruth
158+
.map(({ x, y, width, height }) => new faceapi.Rect(x, y, width, height))
159+
.map(rect => rect.toBoundingBox()),
160+
reshapedImgDims
161+
)
162+
163+
const mask = getGroundTruthMask(
164+
groundTruthBoxes,
165+
inputSize
166+
)
167+
const inverseMask = tf.tidy(() => tf.sub(tf.scalar(1), mask))
168+
169+
const noObjectLoss = tf.tidy(() =>
170+
tf.mul(
171+
tf.scalar(noObjectScale),
172+
tf.sum(tf.mul(inverseMask, computeNoObjectLoss(outTensor)))
173+
)
174+
)
175+
const objectLoss = tf.tidy(() =>
176+
tf.mul(
177+
tf.scalar(objectScale),
178+
tf.sum(tf.mul(mask, computeObjectLoss(outTensor, groundTruthBoxes, reshapedImgDims, paddings)))
179+
)
180+
)
181+
182+
const coordLoss = tf.tidy(() =>
183+
tf.mul(
184+
tf.scalar(coordScale),
185+
tf.sum(tf.mul(mask, computeCoordLoss(groundTruthBoxes, outTensor, reshapedImgDims)))
186+
)
187+
)
188+
189+
const totalLoss = tf.tidy(() => noObjectLoss.add(objectLoss).add(coordLoss))
190+
191+
return {
192+
noObjectLoss,
193+
objectLoss,
194+
coordLoss,
195+
totalLoss
196+
}
106197
}

0 commit comments

Comments
 (0)