Skip to content

Commit 8490918

Browse files
weight loading of quantized tiny-yolov2-seperable-conv2d + use separable conv model by default + testcases
1 parent 4f45297 commit 8490918

File tree

7 files changed

+221
-107
lines changed

7 files changed

+221
-107
lines changed

src/tinyYolov2/TinyYolov2.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ import { NetParams, PostProcessingParams, TinyYolov2ForwardParams } from './type
2020

2121
export class TinyYolov2 extends NeuralNetwork<NetParams> {
2222

23-
private _hasSeparableConvs: boolean
23+
private _withSeparableConvs: boolean
2424
private _anchors: Point[]
2525

26-
constructor(hasSeparableConvs: boolean = false) {
26+
constructor(withSeparableConvs: boolean = true) {
2727
super('TinyYolov2')
28-
this._hasSeparableConvs = hasSeparableConvs
29-
this._anchors = hasSeparableConvs ? BOX_ANCHORS_SEPARABLE : BOX_ANCHORS
28+
this._withSeparableConvs = withSeparableConvs
29+
this._anchors = withSeparableConvs ? BOX_ANCHORS_SEPARABLE : BOX_ANCHORS
3030
}
3131

32-
public get hasSeparableConvs(): boolean {
33-
return this._hasSeparableConvs
32+
public get withSeparableConvs(): boolean {
33+
return this._withSeparableConvs
3434
}
3535

3636
public get anchors(): Point[] {
@@ -48,7 +48,7 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
4848
const out = tf.tidy(() => {
4949

5050
let batchTensor = input.toBatchTensor(inputSize, false)
51-
batchTensor = this.hasSeparableConvs
51+
batchTensor = this.withSeparableConvs
5252
? normalize(batchTensor, MEAN_RGB)
5353
: batchTensor
5454
batchTensor = batchTensor.div(tf.scalar(256)) as tf.Tensor4D
@@ -132,7 +132,7 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
132132
const numCells = outputTensor.shape[1]
133133

134134
const [boxesTensor, scoresTensor] = tf.tidy(() => {
135-
const reshaped = outputTensor.reshape([numCells, numCells, NUM_BOXES, this.hasSeparableConvs ? 5 : 6])
135+
const reshaped = outputTensor.reshape([numCells, numCells, NUM_BOXES, this.withSeparableConvs ? 5 : 6])
136136

137137
const boxes = reshaped.slice([0, 0, 0, 0], [numCells, numCells, NUM_BOXES, 4])
138138
const scores = reshaped.slice([0, 0, 0, 4], [numCells, numCells, NUM_BOXES, 1])
@@ -172,10 +172,10 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
172172
}
173173

174174
protected loadQuantizedParams(uri: string | undefined) {
175-
return loadQuantizedParams(uri)
175+
return loadQuantizedParams(uri, this.withSeparableConvs)
176176
}
177177

178178
protected extractParams(weights: Float32Array) {
179-
return extractParams(weights, this.hasSeparableConvs)
179+
return extractParams(weights, this.withSeparableConvs)
180180
}
181181
}

src/tinyYolov2/extractParams.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings
5656

5757
}
5858

59-
export function extractParams(weights: Float32Array, hasSeparableConvs: boolean): { params: NetParams, paramMappings: ParamMapping[] } {
59+
export function extractParams(weights: Float32Array, withSeparableConvs: boolean): { params: NetParams, paramMappings: ParamMapping[] } {
6060

6161
const {
6262
extractWeights,
@@ -71,8 +71,8 @@ export function extractParams(weights: Float32Array, hasSeparableConvs: boolean)
7171
extractSeparableConvParams
7272
} = extractorsFactory(extractWeights, paramMappings)
7373

74-
const extractConvFn = hasSeparableConvs ? extractSeparableConvParams : extractConvWithBatchNormParams
75-
const numAnchorEncodings = hasSeparableConvs ? 5 : 6
74+
const extractConvFn = withSeparableConvs ? extractSeparableConvParams : extractConvWithBatchNormParams
75+
const numAnchorEncodings = withSeparableConvs ? 5 : 6
7676

7777
const conv0 = extractConvFn(3, 16, 'conv0',)
7878
const conv1 = extractConvFn(16, 32, 'conv1')

src/tinyYolov2/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ import { TinyYolov2 } from './TinyYolov2';
22

33
export * from './TinyYolov2';
44

5-
export function createTinyYolov2(weights: Float32Array) {
6-
const net = new TinyYolov2()
5+
export function createTinyYolov2(weights: Float32Array, withSeparableConvs: boolean = true) {
6+
const net = new TinyYolov2(withSeparableConvs)
77
net.extractWeights(weights)
88
return net
99
}

src/tinyYolov2/loadQuantizedParams.ts

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensor
44
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
55
import { loadWeightMap } from '../commons/loadWeightMap';
66
import { ConvParams, ParamMapping } from '../commons/types';
7-
import { BatchNorm, ConvWithBatchNorm, NetParams } from './types';
7+
import { BatchNorm, ConvWithBatchNorm, NetParams, SeparableConvParams } from './types';
88

99
const DEFAULT_MODEL_NAME = 'tiny_yolov2_model'
10+
const DEFAULT_MODEL_NAME_SEPARABLE_CONV = 'tiny_yolov2_separable_conv_model'
1011

1112
function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
1213

@@ -30,35 +31,51 @@ function extractorsFactory(weightMap: any, paramMappings: ParamMapping[]) {
3031
return { conv, bn }
3132
}
3233

34+
function extractSeparableConvParams(prefix: string): SeparableConvParams {
35+
const depthwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/depthwise_filter`, 4)
36+
const pointwise_filter = extractWeightEntry<tf.Tensor4D>(`${prefix}/pointwise_filter`, 4)
37+
const bias = extractWeightEntry<tf.Tensor1D>(`${prefix}/bias`, 1)
38+
39+
return new SeparableConvParams(
40+
depthwise_filter,
41+
pointwise_filter,
42+
bias
43+
)
44+
}
45+
3346
return {
3447
extractConvParams,
35-
extractConvWithBatchNormParams
48+
extractConvWithBatchNormParams,
49+
extractSeparableConvParams
3650
}
3751

3852
}
3953

4054
export async function loadQuantizedParams(
41-
uri: string | undefined
55+
uri: string | undefined,
56+
withSeparableConvs: boolean
4257
): Promise<{ params: NetParams, paramMappings: ParamMapping[] }> {
4358

44-
const weightMap = await loadWeightMap(uri, DEFAULT_MODEL_NAME)
59+
const weightMap = await loadWeightMap(uri, withSeparableConvs ? DEFAULT_MODEL_NAME_SEPARABLE_CONV : DEFAULT_MODEL_NAME)
4560
const paramMappings: ParamMapping[] = []
4661

4762
const {
4863
extractConvParams,
49-
extractConvWithBatchNormParams
64+
extractConvWithBatchNormParams,
65+
extractSeparableConvParams
5066
} = extractorsFactory(weightMap, paramMappings)
5167

68+
const extractConvFn = withSeparableConvs ? extractSeparableConvParams : extractConvWithBatchNormParams
5269

5370
const params = {
54-
conv0: extractConvWithBatchNormParams('conv0'),
55-
conv1: extractConvWithBatchNormParams('conv1'),
56-
conv2: extractConvWithBatchNormParams('conv2'),
57-
conv3: extractConvWithBatchNormParams('conv3'),
58-
conv4: extractConvWithBatchNormParams('conv4'),
59-
conv5: extractConvWithBatchNormParams('conv5'),
60-
conv6: extractConvWithBatchNormParams('conv6'),
61-
conv7: extractConvWithBatchNormParams('conv7'),
71+
conv0: extractConvFn('conv0'),
72+
conv1: extractConvFn('conv1'),
73+
conv2: extractConvFn('conv2'),
74+
conv3: extractConvFn('conv3'),
75+
conv4: extractConvFn('conv4'),
76+
conv5: extractConvFn('conv5'),
77+
conv6: extractConvFn('conv6'),
78+
conv7: extractConvFn('conv7'),
6279
conv8: extractConvParams('conv8')
6380
}
6481

test/tests/e2e/expectedResults.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ export const expectedTinyYolov2Boxes = [
2929
{ x: 87, y: 30, width: 92, height: 93 }
3030
]
3131

32+
export const expectedTinyYolov2SeparableConvBoxes = [
33+
{ x: 42, y: 257, width: 111, height: 121 },
34+
{ x: 454, y: 175, width: 104, height: 121 },
35+
{ x: 230, y: 45, width: 94, height: 104 },
36+
{ x: 574, y: 62, width: 88, height: 113 },
37+
{ x: 260, y: 233, width: 82, height: 104 },
38+
{ x: 83, y: 24, width: 85, height: 111 }
39+
]
40+
3241
export const expectedMtcnnFaceLandmarks = [
3342
[new Point(117, 58), new Point(156, 63), new Point(141, 86), new Point(109, 98), new Point(147, 104)],
3443
[new Point(82, 292), new Point(134, 304), new Point(104, 330), new Point(72, 342), new Point(120, 353)],
@@ -38,7 +47,6 @@ export const expectedMtcnnFaceLandmarks = [
3847
[new Point(489, 224), new Point(534, 223), new Point(507, 250), new Point(493, 271), new Point(530, 270)]
3948
]
4049

41-
4250
export function expectMtcnnResults(
4351
results: { faceDetection: faceapi.FaceDetection, faceLandmarks: faceapi.FaceLandmarks5 }[],
4452
boxOrder: number[],

0 commit comments

Comments
 (0)