Skip to content

Commit 542dc68

Browse files
prepare tinyYolov2 for using depthwise seperable conv2ds + retraining
1 parent f81d35a commit 542dc68

File tree

11 files changed

+189
-84
lines changed

11 files changed

+189
-84
lines changed

src/BoundingBox.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import { Rect } from './Rect';
2+
import { Dimensions } from './types';
3+
import { isDimensions } from './utils';
24

35
export class BoundingBox {
46
constructor(
@@ -33,6 +35,10 @@ export class BoundingBox {
3335
return this.bottom - this.top
3436
}
3537

38+
public get area() : number {
39+
return this.width * this.height
40+
}
41+
3642
public toSquare(): BoundingBox {
3743
let { left, top, right, bottom } = this
3844

@@ -100,6 +106,12 @@ export class BoundingBox {
100106
).toSquare().round()
101107
}
102108

109+
public rescale(s: Dimensions | number) {
110+
const scaleX = isDimensions(s) ? (s as Dimensions).width : s as number
111+
const scaleY = isDimensions(s) ? (s as Dimensions).height : s as number
112+
return new BoundingBox(this.left * scaleX, this.top * scaleY, this.right * scaleX, this.bottom * scaleY)
113+
}
114+
103115
public toRect(): Rect {
104116
return new Rect(this.left, this.top, this.width, this.height)
105117
}

src/commons/extractWeightsFactory.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ export function extractWeightsFactory(weights: Float32Array) {
22
let remainingWeights = weights
33

44
function extractWeights(numWeights: number): Float32Array {
5+
console.log(numWeights)
56
const ret = remainingWeights.slice(0, numWeights)
67
remainingWeights = remainingWeights.slice(numWeights)
78
return ret

src/commons/nonMaxSuppression.ts

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { BoundingBox } from '../BoundingBox';
2+
import { iou } from '../iou';
23

34
export function nonMaxSuppression(
45
boxes: BoundingBox[],
@@ -7,10 +8,6 @@ export function nonMaxSuppression(
78
isIOU: boolean = true
89
): number[] {
910

10-
const areas = boxes.map(
11-
box => (box.width + 1) * (box.height + 1)
12-
)
13-
1411
let indicesSortedByScore = scores
1512
.map((score, boxIndex) => ({ score, boxIndex }))
1613
.sort((c1, c2) => c1.score - c2.score)
@@ -31,15 +28,7 @@ export function nonMaxSuppression(
3128
const currBox = boxes[curr]
3229
const idxBox = boxes[idx]
3330

34-
const width = Math.max(0.0, Math.min(currBox.right, idxBox.right) - Math.max(currBox.left, idxBox.left) + 1)
35-
const height = Math.max(0.0, Math.min(currBox.bottom, idxBox.bottom) - Math.max(currBox.top, idxBox.top) + 1)
36-
const interSection = width * height
37-
38-
const out = isIOU
39-
? interSection / (areas[curr] + areas[idx] - interSection)
40-
: interSection / Math.min(areas[curr], areas[idx])
41-
42-
outputs.push(out)
31+
outputs.push(iou(currBox, idxBox, isIOU))
4332
}
4433

4534
indicesSortedByScore = indicesSortedByScore.filter(

src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export * from './faceDetectionNet';
1919
export * from './faceLandmarkNet';
2020
export * from './faceRecognitionNet';
2121
export * from './globalApi';
22+
export * from './iou';
2223
export * from './mtcnn';
2324
export * from './padToSquare';
2425
export * from './tinyYolov2';

src/iou.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import { BoundingBox } from './BoundingBox';
2+
3+
export function iou(box1: BoundingBox, box2: BoundingBox, isIOU: boolean = true) {
4+
const width = Math.max(0.0, Math.min(box1.right, box2.right) - Math.max(box1.left, box2.left) + 1)
5+
const height = Math.max(0.0, Math.min(box1.bottom, box2.bottom) - Math.max(box1.top, box2.top) + 1)
6+
const interSection = width * height
7+
8+
return isIOU
9+
? interSection / (box1.area + box2.area - interSection)
10+
: interSection / Math.min(box1.area, box2.area)
11+
}

src/tinyYolov2/TinyYolov2.ts

Lines changed: 74 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,34 @@ import { NeuralNetwork } from '../commons/NeuralNetwork';
66
import { nonMaxSuppression } from '../commons/nonMaxSuppression';
77
import { FaceDetection } from '../FaceDetection';
88
import { NetInput } from '../NetInput';
9+
import { Point } from '../Point';
910
import { toNetInput } from '../toNetInput';
1011
import { TNetInput } from '../types';
11-
import { BOX_ANCHORS, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES } from './config';
12+
import { sigmoid } from '../utils';
13+
import { BOX_ANCHORS, BOX_ANCHORS_SEPARABLE, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES } from './config';
1214
import { convWithBatchNorm } from './convWithBatchNorm';
1315
import { extractParams } from './extractParams';
1416
import { getDefaultParams } from './getDefaultParams';
1517
import { loadQuantizedParams } from './loadQuantizedParams';
16-
import { NetParams, TinyYolov2ForwardParams } from './types';
18+
import { NetParams, PostProcessingParams, TinyYolov2ForwardParams } from './types';
1719

1820
export class TinyYolov2 extends NeuralNetwork<NetParams> {
1921

20-
constructor() {
22+
private _hasSeparableConvs: boolean
23+
private _anchors: Point[]
24+
25+
constructor(hasSeparableConvs: boolean = false) {
2126
super('TinyYolov2')
27+
this._hasSeparableConvs = hasSeparableConvs
28+
this._anchors = hasSeparableConvs ? BOX_ANCHORS_SEPARABLE : BOX_ANCHORS
29+
}
30+
31+
public get hasSeparableConvs(): boolean {
32+
return this._hasSeparableConvs
33+
}
34+
35+
public get anchors(): Point[] {
36+
return this._anchors
2237
}
2338

2439
public forwardInput(input: NetInput, inputSize: number): tf.Tensor4D {
@@ -30,7 +45,7 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
3045
}
3146

3247
const out = tf.tidy(() => {
33-
const batchTensor = input.toBatchTensor(inputSize, false).div(tf.scalar(255)).toFloat() as tf.Tensor4D
48+
const batchTensor = input.toBatchTensor(inputSize, false).div(tf.scalar(255)) as tf.Tensor4D
3449

3550
let out = convWithBatchNorm(batchTensor, params.conv0)
3651
out = tf.maxPool(out, [2, 2], [2, 2], 'same')
@@ -72,39 +87,74 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
7287

7388
const netInput = await toNetInput(input, true)
7489
const out = await this.forwardInput(netInput, inputSize)
75-
const numCells = out.shape[1]
90+
91+
const inputDimensions = {
92+
width: netInput.getInputWidth(0),
93+
height: netInput.getInputHeight(0)
94+
}
95+
96+
const paddings = new Point(
97+
(netInput.getPaddings(0).x + netInput.getInputWidth(0)) / netInput.getInputWidth(0),
98+
(netInput.getPaddings(0).y + netInput.getInputHeight(0)) / netInput.getInputHeight(0)
99+
)
100+
101+
const results = this.postProcess(out, { scoreThreshold, paddings })
102+
const boxes = results.map(res => res.box)
103+
const scores = results.map(res => res.score)
104+
105+
out.dispose()
106+
107+
const indices = nonMaxSuppression(
108+
boxes.map(box => box.rescale(inputSize)),
109+
scores,
110+
IOU_THRESHOLD,
111+
true
112+
)
113+
114+
const detections = indices.map(idx =>
115+
new FaceDetection(
116+
scores[idx],
117+
boxes[idx].toRect(),
118+
inputDimensions
119+
)
120+
)
121+
122+
return detections
123+
}
124+
125+
public postProcess(outputTensor: tf.Tensor4D, { scoreThreshold, paddings }: PostProcessingParams) {
126+
127+
const numCells = outputTensor.shape[1]
76128

77129
const [boxesTensor, scoresTensor] = tf.tidy(() => {
78-
const reshaped = out.reshape([numCells, numCells, NUM_BOXES, 6])
79-
out.dispose()
130+
const reshaped = outputTensor.reshape([numCells, numCells, NUM_BOXES, this.hasSeparableConvs ? 5 : 6])
80131

81132
const boxes = reshaped.slice([0, 0, 0, 0], [numCells, numCells, NUM_BOXES, 4])
82133
const scores = reshaped.slice([0, 0, 0, 4], [numCells, numCells, NUM_BOXES, 1])
83134
return [boxes, scores]
84135
})
85136

86-
const expit = (x: number): number => 1 / (1 + Math.exp(-x))
87-
88-
const paddedHeightRelative = (netInput.getPaddings(0).y + netInput.getInputHeight(0)) / netInput.getInputHeight(0)
89-
const paddedWidthRelative = (netInput.getPaddings(0).x + netInput.getInputWidth(0)) / netInput.getInputWidth(0)
90-
91-
const boxes: BoundingBox[] = []
92-
const scores: number[] = []
137+
const results = []
93138

94139
for (let row = 0; row < numCells; row ++) {
95140
for (let col = 0; col < numCells; col ++) {
96-
for (let box = 0; box < NUM_BOXES; box ++) {
97-
const score = expit(scoresTensor.get(row, col, box, 0))
141+
for (let anchor = 0; anchor < NUM_BOXES; anchor ++) {
142+
const score = sigmoid(scoresTensor.get(row, col, anchor, 0))
98143
if (score > scoreThreshold) {
99-
const ctX = ((col + expit(boxesTensor.get(row, col, box, 0))) / numCells) * paddedWidthRelative
100-
const ctY = ((row + expit(boxesTensor.get(row, col, box, 1))) / numCells) * paddedHeightRelative
101-
const width = ((Math.exp(boxesTensor.get(row, col, box, 2)) * BOX_ANCHORS[box].x) / numCells) * paddedWidthRelative
102-
const height = ((Math.exp(boxesTensor.get(row, col, box, 3)) * BOX_ANCHORS[box].y) / numCells) * paddedHeightRelative
144+
const ctX = ((col + sigmoid(boxesTensor.get(row, col, anchor, 0))) / numCells) * paddings.x
145+
const ctY = ((row + sigmoid(boxesTensor.get(row, col, anchor, 1))) / numCells) * paddings.y
146+
const width = ((Math.exp(boxesTensor.get(row, col, anchor, 2)) * this.anchors[anchor].x) / numCells) * paddings.x
147+
const height = ((Math.exp(boxesTensor.get(row, col, anchor, 3)) * this.anchors[anchor].y) / numCells) * paddings.y
103148

104149
const x = (ctX - (width / 2))
105150
const y = (ctY - (height / 2))
106-
boxes.push(new BoundingBox(x, y, x + width, y + height))
107-
scores.push(score)
151+
results.push({
152+
box: new BoundingBox(x, y, x + width, y + height),
153+
score,
154+
row,
155+
col,
156+
anchor
157+
})
108158
}
109159
}
110160
}
@@ -113,34 +163,14 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
113163
boxesTensor.dispose()
114164
scoresTensor.dispose()
115165

116-
const indices = nonMaxSuppression(
117-
boxes.map(box => new BoundingBox(
118-
box.left * inputSize,
119-
box.top * inputSize,
120-
box.right * inputSize,
121-
box.bottom * inputSize
122-
)),
123-
scores,
124-
IOU_THRESHOLD,
125-
true
126-
)
127-
128-
const detections = indices.map(idx =>
129-
new FaceDetection(
130-
scores[idx],
131-
boxes[idx].toRect(),
132-
{ width: netInput.getInputWidth(0), height: netInput.getInputHeight(0) }
133-
)
134-
)
135-
136-
return detections
166+
return results
137167
}
138168

139169
protected loadQuantizedParams(uri: string | undefined) {
140170
return loadQuantizedParams(uri)
141171
}
142172

143173
protected extractParams(weights: Float32Array) {
144-
return extractParams(weights)
174+
return extractParams(weights, this.hasSeparableConvs)
145175
}
146176
}

src/tinyYolov2/config.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ export const BOX_ANCHORS = [
1010
new Point(4.30971, 7.04493),
1111
new Point(10.246, 4.59428),
1212
new Point(12.6868, 11.8741)
13+
]
14+
15+
export const BOX_ANCHORS_SEPARABLE = [
16+
new Point(1.603231, 2.094468),
17+
new Point(6.041143, 7.080126),
18+
new Point(2.882459, 3.518061),
19+
new Point(4.266906, 5.178857),
20+
new Point(9.041765, 10.66308)
1321
]

src/tinyYolov2/convWithBatchNorm.ts

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { leaky } from './leaky';
4-
import { ConvWithBatchNorm } from './types';
4+
import { ConvWithBatchNorm, SeparableConvParams } from './types';
55

6-
export function convWithBatchNorm(x: tf.Tensor4D, params: ConvWithBatchNorm): tf.Tensor4D {
6+
export function convWithBatchNorm(x: tf.Tensor4D, params: ConvWithBatchNorm | SeparableConvParams): tf.Tensor4D {
77
return tf.tidy(() => {
88
let out = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]]) as tf.Tensor4D
9-
out = tf.conv2d(out, params.conv.filters, [1, 1], 'valid')
10-
out = tf.sub(out, params.bn.sub)
11-
out = tf.mul(out, params.bn.truediv)
12-
out = tf.add(out, params.conv.bias)
9+
10+
if (params instanceof SeparableConvParams) {
11+
out = tf.separableConv2d(out, params.depthwise_filter, params.pointwise_filter, [1, 1], 'same')
12+
out = tf.add(out, params.bias)
13+
} else {
14+
out = tf.conv2d(out, params.conv.filters, [1, 1], 'valid')
15+
out = tf.sub(out, params.bn.sub)
16+
out = tf.mul(out, params.bn.truediv)
17+
out = tf.add(out, params.conv.bias)
18+
}
19+
1320
return leaky(out)
1421
})
1522
}

src/tinyYolov2/extractParams.ts

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import * as tf from '@tensorflow/tfjs-core';
33
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
44
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
55
import { ExtractWeightsFunction, ParamMapping } from '../commons/types';
6-
import { BatchNorm, ConvWithBatchNorm, NetParams } from './types';
6+
import { BatchNorm, ConvWithBatchNorm, NetParams, SeparableConvParams } from './types';
77

88
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
99

@@ -30,14 +30,34 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings
3030
return { conv, bn }
3131
}
3232

33+
function extractSeparableConvParams(channelsIn: number, channelsOut: number, mappedPrefix: string): SeparableConvParams {
34+
console.log(mappedPrefix)
35+
const depthwise_filter = tf.tensor4d(extractWeights(3 * 3 * channelsIn), [3, 3, channelsIn, 1])
36+
const pointwise_filter = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
37+
const bias = tf.tensor1d(extractWeights(channelsOut))
38+
console.log('done')
39+
paramMappings.push(
40+
{ paramPath: `${mappedPrefix}/depthwise_filter` },
41+
{ paramPath: `${mappedPrefix}/pointwise_filter` },
42+
{ paramPath: `${mappedPrefix}/bias` }
43+
)
44+
45+
return new SeparableConvParams(
46+
depthwise_filter,
47+
pointwise_filter,
48+
bias
49+
)
50+
}
51+
3352
return {
3453
extractConvParams,
35-
extractConvWithBatchNormParams
54+
extractConvWithBatchNormParams,
55+
extractSeparableConvParams
3656
}
3757

3858
}
3959

40-
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
60+
export function extractParams(weights: Float32Array, hasSeparableConvs: boolean): { params: NetParams, paramMappings: ParamMapping[] } {
4161

4262
const {
4363
extractWeights,
@@ -48,18 +68,22 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
4868

4969
const {
5070
extractConvParams,
51-
extractConvWithBatchNormParams
71+
extractConvWithBatchNormParams,
72+
extractSeparableConvParams
5273
} = extractorsFactory(extractWeights, paramMappings)
5374

54-
const conv0 = extractConvWithBatchNormParams(3, 16, 'conv0')
55-
const conv1 = extractConvWithBatchNormParams(16, 32, 'conv1')
56-
const conv2 = extractConvWithBatchNormParams(32, 64, 'conv2')
57-
const conv3 = extractConvWithBatchNormParams(64, 128, 'conv3')
58-
const conv4 = extractConvWithBatchNormParams(128, 256, 'conv4')
59-
const conv5 = extractConvWithBatchNormParams(256, 512, 'conv5')
60-
const conv6 = extractConvWithBatchNormParams(512, 1024, 'conv6')
61-
const conv7 = extractConvWithBatchNormParams(1024, 1024, 'conv7')
62-
const conv8 = extractConvParams(1024, 30, 1, 'conv8')
75+
const extractConvFn = hasSeparableConvs ? extractSeparableConvParams : extractConvWithBatchNormParams
76+
const numAnchorEncodings = hasSeparableConvs ? 5 : 6
77+
78+
const conv0 = extractConvFn(3, 16, 'conv0',)
79+
const conv1 = extractConvFn(16, 32, 'conv1')
80+
const conv2 = extractConvFn(32, 64, 'conv2')
81+
const conv3 = extractConvFn(64, 128, 'conv3')
82+
const conv4 = extractConvFn(128, 256, 'conv4')
83+
const conv5 = extractConvFn(256, 512, 'conv5')
84+
const conv6 = extractConvFn(512, 1024, 'conv6')
85+
const conv7 = extractConvFn(1024, 1024, 'conv7')
86+
const conv8 = extractConvParams(1024, 5 * numAnchorEncodings, 1, 'conv8')
6387

6488
if (getRemainingWeights().length !== 0) {
6589
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)

0 commit comments

Comments
 (0)