Skip to content

Commit 886cf99

Browse files
init tiny yolov2 + weight loading
1 parent 9ef6351 commit 886cf99

File tree

8 files changed

+194
-0
lines changed

8 files changed

+194
-0
lines changed

src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ export * from './faceRecognitionNet';
2121
export * from './globalApi';
2222
export * from './mtcnn';
2323
export * from './padToSquare';
24+
export * from './tinyYolov2';
2425
export * from './toNetInput';
2526
export * from './utils'

src/tinyYolov2/TinyYolov2.ts

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { convLayer } from '../commons/convLayer';
4+
import { NeuralNetwork } from '../commons/NeuralNetwork';
5+
import { NetInput } from '../NetInput';
6+
import { toNetInput } from '../toNetInput';
7+
import { TNetInput } from '../types';
8+
import { convWithBatchNorm } from './convWithBatchNorm';
9+
import { extractParams } from './extractParams';
10+
import { NetParams } from './types';
11+
12+
13+
export class TinyYolov2 extends NeuralNetwork<NetParams> {
14+
15+
constructor() {
16+
super('TinyYolov2')
17+
}
18+
19+
public async forwardInput(input: NetInput): Promise<any> {
20+
21+
const { params } = this
22+
23+
if (!params) {
24+
throw new Error('TinyYolov2 - load model before inference')
25+
}
26+
27+
const out = tf.tidy(() => {
28+
const batchTensor = input.toBatchTensor(416).div(tf.scalar(255)).toFloat()
29+
30+
let out = tf.pad(batchTensor, [[0, 0], [1, 1], [1, 1], [0, 0]]) as tf.Tensor4D
31+
32+
out = convWithBatchNorm(out, params.conv0)
33+
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
34+
out = convWithBatchNorm(out, params.conv1)
35+
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
36+
out = convWithBatchNorm(out, params.conv2)
37+
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
38+
out = convWithBatchNorm(out, params.conv3)
39+
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
40+
out = convWithBatchNorm(out, params.conv4)
41+
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
42+
out = convWithBatchNorm(out, params.conv5)
43+
out = tf.maxPool(out, [2, 2], [1, 1], 'valid')
44+
out = convWithBatchNorm(out, params.conv6)
45+
out = convWithBatchNorm(out, params.conv7)
46+
out = convLayer(out, params.conv8, 'valid', false)
47+
48+
return out
49+
})
50+
51+
return out
52+
}
53+
54+
public async forward(input: TNetInput): Promise<any> {
55+
return await this.forwardInput(await toNetInput(input, true, true))
56+
}
57+
58+
/* TODO
59+
protected loadQuantizedParams(uri: string | undefined) {
60+
return loadQuantizedParams(uri)
61+
}
62+
*/
63+
64+
protected extractParams(weights: Float32Array) {
65+
return extractParams(weights)
66+
}
67+
}

src/tinyYolov2/convWithBatchNorm.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { leaky } from './leaky';
4+
import { ConvWithBatchNorm } from './types';
5+
6+
export function convWithBatchNorm(x: tf.Tensor4D, params: ConvWithBatchNorm): tf.Tensor4D {
7+
return tf.tidy(() => {
8+
let out = tf.conv2d(x, params.conv.filters, [1, 1], 'valid')
9+
out = tf.sub(out, params.bn.sub)
10+
out = tf.mul(out, params.bn.truediv)
11+
out = tf.add(out, params.conv.bias)
12+
return leaky(out)
13+
})
14+
}

src/tinyYolov2/extractParams.ts

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
4+
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
5+
import { ExtractWeightsFunction, ParamMapping } from '../commons/types';
6+
import { BatchNorm, ConvWithBatchNorm, NetParams } from './types';
7+
8+
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
9+
10+
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
11+
12+
function extractBatchNormParams(size: number, mappedPrefix: string): BatchNorm {
13+
14+
const sub = tf.tensor1d(extractWeights(size))
15+
const truediv = tf.tensor1d(extractWeights(size))
16+
17+
paramMappings.push(
18+
{ paramPath: `${mappedPrefix}/sub` },
19+
{ paramPath: `${mappedPrefix}/truediv` }
20+
)
21+
22+
return { sub, truediv }
23+
}
24+
25+
function extractConvWithBatchNormParams(channelsIn: number, channelsOut: number, mappedPrefix: string): ConvWithBatchNorm {
26+
27+
const conv = extractConvParams(channelsIn, channelsOut, 3, `${mappedPrefix}/conv`)
28+
const bn = extractBatchNormParams(channelsOut, `${mappedPrefix}/bn`)
29+
30+
return { conv, bn }
31+
}
32+
33+
return {
34+
extractConvParams,
35+
extractConvWithBatchNormParams
36+
}
37+
38+
}
39+
40+
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
41+
42+
const {
43+
extractWeights,
44+
getRemainingWeights
45+
} = extractWeightsFactory(weights)
46+
47+
const paramMappings: ParamMapping[] = []
48+
49+
const {
50+
extractConvParams,
51+
extractConvWithBatchNormParams
52+
} = extractorsFactory(extractWeights, paramMappings)
53+
54+
const conv0 = extractConvWithBatchNormParams(3, 16, 'conv0')
55+
const conv1 = extractConvWithBatchNormParams(16, 32, 'conv1')
56+
const conv2 = extractConvWithBatchNormParams(32, 64, 'conv2')
57+
const conv3 = extractConvWithBatchNormParams(64, 128, 'conv3')
58+
const conv4 = extractConvWithBatchNormParams(128, 256, 'conv4')
59+
const conv5 = extractConvWithBatchNormParams(256, 512, 'conv5')
60+
const conv6 = extractConvWithBatchNormParams(512, 1024, 'conv6')
61+
const conv7 = extractConvWithBatchNormParams(1024, 1024, 'conv7')
62+
const conv8 = extractConvParams(1024, 30, 1, 'conv8')
63+
64+
if (getRemainingWeights().length !== 0) {
65+
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
66+
}
67+
68+
const params = { conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8 }
69+
70+
return { params, paramMappings }
71+
}

src/tinyYolov2/index.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { TinyYolov2 } from './TinyYolov2';
2+
3+
export * from './TinyYolov2';
4+
5+
export function createTinyYolov2(weights: Float32Array) {
6+
const net = new TinyYolov2()
7+
net.extractWeights(weights)
8+
return net
9+
}

src/tinyYolov2/leaky.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
export function leaky(x: tf.Tensor4D): tf.Tensor4D {
4+
return tf.tidy(() => {
5+
return tf.maximum(x, tf.mul(x, tf.scalar(0.10000000149011612)))
6+
})
7+
}

src/tinyYolov2/types.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ConvParams } from '../commons/types';
4+
5+
export type BatchNorm = {
6+
sub: tf.Tensor1D
7+
truediv: tf.Tensor1D
8+
}
9+
10+
export type ConvWithBatchNorm = {
11+
conv: ConvParams
12+
bn: BatchNorm
13+
}
14+
15+
export type NetParams = {
16+
conv0: ConvWithBatchNorm
17+
conv1: ConvWithBatchNorm
18+
conv2: ConvWithBatchNorm
19+
conv3: ConvWithBatchNorm
20+
conv4: ConvWithBatchNorm
21+
conv5: ConvWithBatchNorm
22+
conv6: ConvWithBatchNorm
23+
conv7: ConvWithBatchNorm
24+
conv8: ConvParams
25+
}
60.1 MB
Binary file not shown.

0 commit comments

Comments
 (0)