Skip to content

Commit f3469d0

Browse files
prediction layer + params
1 parent 8ad95e3 commit f3469d0

File tree

6 files changed

+157
-65
lines changed

6 files changed

+157
-65
lines changed

src/faceDetectionNet/boxPredictionLayer.ts

Lines changed: 13 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -2,28 +2,16 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { FaceDetectionNet } from './types';
44

5-
function boxEncodingPredictionLayer(
5+
function convWithBias(
66
x: tf.Tensor4D,
77
params: FaceDetectionNet.ConvWithBiasParams
88
) {
9-
return tf.tidy(() => {
10-
11-
// TODO
12-
return x
13-
14-
})
15-
}
16-
17-
function classPredictionLayer(
18-
x: tf.Tensor4D,
19-
params: FaceDetectionNet.ConvWithBiasParams
20-
) {
21-
return tf.tidy(() => {
22-
23-
// TODO
24-
return x
25-
26-
})
9+
return tf.tidy(() =>
10+
tf.add(
11+
tf.conv2d(x, params.filters, [1, 1], 'same'),
12+
params.bias
13+
)
14+
)
2715
}
2816

2917
export function boxPredictionLayer(
@@ -33,13 +21,15 @@ export function boxPredictionLayer(
3321
) {
3422
return tf.tidy(() => {
3523

24+
const batchSize = x.shape[0]
25+
3626
const boxPredictionEncoding = tf.reshape(
37-
boxEncodingPredictionLayer(x, params.box_encoding_predictor_params),
38-
[x.shape[0], size, 1, 4]
27+
convWithBias(x, params.box_encoding_predictor_params),
28+
[batchSize, size, 1, 4]
3929
)
4030
const classPrediction = tf.reshape(
41-
classPredictionLayer(x, params.class_predictor_params),
42-
[x.shape[0], size, 3]
31+
convWithBias(x, params.class_predictor_params),
32+
[batchSize, size, 3]
4333
)
4434

4535
return {

src/faceDetectionNet/extractParams.ts

Lines changed: 102 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,42 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
2020
}
2121
}
2222

23-
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.PointwiseConvParams {
24-
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
25-
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut))
23+
function extractConvWithBiasParams(
24+
channelsIn: number,
25+
channelsOut: number,
26+
filterSize: number
27+
): FaceDetectionNet.ConvWithBiasParams {
28+
const filters = tf.tensor4d(
29+
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
30+
[filterSize, filterSize, channelsIn, channelsOut]
31+
)
32+
const bias = tf.tensor1d(extractWeights(channelsOut))
2633

2734
return {
2835
filters,
29-
batch_norm_offset
36+
bias
37+
}
38+
}
39+
40+
function extractPointwiseConvParams(
41+
channelsIn: number,
42+
channelsOut: number,
43+
filterSize: number
44+
): FaceDetectionNet.PointwiseConvParams {
45+
const {
46+
filters,
47+
bias
48+
} = extractConvWithBiasParams(channelsIn, channelsOut, filterSize)
49+
50+
return {
51+
filters,
52+
batch_norm_offset: bias
3053
}
3154
}
3255

3356
function extractConvPairParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.ConvPairParams {
3457
const depthwise_conv_params = extractDepthwiseConvParams(channelsIn)
35-
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut)
58+
const pointwise_conv_params = extractPointwiseConvParams(channelsIn, channelsOut, 1)
3659

3760
return {
3861
depthwise_conv_params,
@@ -42,11 +65,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
4265

4366
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
4467

45-
const conv_0_params = {
46-
filters: tf.tensor4d(extractWeights(3 * 3 * 3 * 32), [3, 3, 3, 32]),
47-
batch_norm_offset: tf.tensor1d(extractWeights(32))
48-
49-
}
68+
const conv_0_params = extractPointwiseConvParams(3, 32, 3)
5069

5170
const channelNumPairs = [
5271
[32, 64],
@@ -75,32 +94,101 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
7594

7695
}
7796

97+
function extractPredictionLayerParams(): FaceDetectionNet.PredictionParams {
98+
const conv_0_params = extractPointwiseConvParams(1024, 256, 1)
99+
const conv_1_params = extractPointwiseConvParams(256, 512, 3)
100+
const conv_2_params = extractPointwiseConvParams(512, 128, 1)
101+
const conv_3_params = extractPointwiseConvParams(128, 256, 3)
102+
const conv_4_params = extractPointwiseConvParams(256, 128, 1)
103+
const conv_5_params = extractPointwiseConvParams(128, 256, 3)
104+
const conv_6_params = extractPointwiseConvParams(256, 64, 1)
105+
const conv_7_params = extractPointwiseConvParams(64, 128, 3)
106+
107+
const box_encoding_0_predictor_params = extractConvWithBiasParams(512, 12, 1)
108+
const class_predictor_0_params = extractConvWithBiasParams(512, 9, 1)
109+
const box_encoding_1_predictor_params = extractConvWithBiasParams(1024, 24, 1)
110+
const class_predictor_1_params = extractConvWithBiasParams(1024, 18, 1)
111+
const box_encoding_2_predictor_params = extractConvWithBiasParams(512, 24, 1)
112+
const class_predictor_2_params = extractConvWithBiasParams(512, 18, 1)
113+
const box_encoding_3_predictor_params = extractConvWithBiasParams(256, 24, 1)
114+
const class_predictor_3_params = extractConvWithBiasParams(256, 18, 1)
115+
const box_encoding_4_predictor_params = extractConvWithBiasParams(256, 24, 1)
116+
const class_predictor_4_params = extractConvWithBiasParams(256, 18, 1)
117+
const box_encoding_5_predictor_params = extractConvWithBiasParams(128, 24, 1)
118+
const class_predictor_5_params = extractConvWithBiasParams(128, 18, 1)
119+
120+
const box_predictor_0_params = {
121+
box_encoding_predictor_params: box_encoding_0_predictor_params,
122+
class_predictor_params: class_predictor_0_params
123+
}
124+
const box_predictor_1_params = {
125+
box_encoding_predictor_params: box_encoding_1_predictor_params,
126+
class_predictor_params: class_predictor_1_params
127+
}
128+
const box_predictor_2_params = {
129+
box_encoding_predictor_params: box_encoding_2_predictor_params,
130+
class_predictor_params: class_predictor_2_params
131+
}
132+
const box_predictor_3_params = {
133+
box_encoding_predictor_params: box_encoding_3_predictor_params,
134+
class_predictor_params: class_predictor_3_params
135+
}
136+
const box_predictor_4_params = {
137+
box_encoding_predictor_params: box_encoding_4_predictor_params,
138+
class_predictor_params: class_predictor_4_params
139+
}
140+
const box_predictor_5_params = {
141+
box_encoding_predictor_params: box_encoding_5_predictor_params,
142+
class_predictor_params: class_predictor_5_params
143+
}
144+
145+
return {
146+
conv_0_params,
147+
conv_1_params,
148+
conv_2_params,
149+
conv_3_params,
150+
conv_4_params,
151+
conv_5_params,
152+
conv_6_params,
153+
conv_7_params,
154+
box_predictor_0_params,
155+
box_predictor_1_params,
156+
box_predictor_2_params,
157+
box_predictor_3_params,
158+
box_predictor_4_params,
159+
box_predictor_5_params
160+
}
161+
}
162+
78163

79164
return {
80-
extractMobilenetV1Params
165+
extractMobilenetV1Params,
166+
extractPredictionLayerParams
81167
}
82168

83169
}
84170

85171
export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams {
86172
const extractWeights = (numWeights: number): Float32Array => {
87-
console.log(numWeights)
88173
const ret = weights.slice(0, numWeights)
89174
weights = weights.slice(numWeights)
90175
return ret
91176
}
92177

93178
const {
94-
extractMobilenetV1Params
179+
extractMobilenetV1Params,
180+
extractPredictionLayerParams
95181
} = extractorsFactory(extractWeights)
96182

97183
const mobilenetv1_params = extractMobilenetV1Params()
184+
const prediction_layer_params = extractPredictionLayerParams()
98185

99186
if (weights.length !== 0) {
100187
throw new Error(`weights remaing after extract: ${weights.length}`)
101188
}
102189

103190
return {
104-
mobilenetv1_params
191+
mobilenetv1_params,
192+
prediction_layer_params
105193
}
106194
}

src/faceDetectionNet/index.ts

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import { isFloat } from '../utils';
44
import { extractParams } from './extractParams';
55
import { mobileNetV1 } from './mobileNetV1';
66
import { resizeLayer } from './resizeLayer';
7+
import { predictionLayer } from './predictionLayer';
78

89
function fromData(input: number[]): tf.Tensor4D {
910
const pxPerChannel = input.length / 3
@@ -47,19 +48,18 @@ export function faceDetectionNet(weights: Float32Array) {
4748
? fromImageData(imgDataArray)
4849
: fromData(input as number[])
4950

50-
let out = resizeLayer(imgTensor) as tf.Tensor4D
51-
out = mobileNetV1(out, params.mobilenetv1_params)
51+
const resized = resizeLayer(imgTensor) as tf.Tensor4D
52+
const features = mobileNetV1(resized, params.mobilenetv1_params)
5253

54+
const {
55+
boxPredictions,
56+
classPredictions
57+
} = predictionLayer(features.out, features.conv11, params.prediction_layer_params)
5358

54-
55-
// boxpredictor0: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
56-
// boxpredictor1: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
57-
// boxpredictor2: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_512/Relu6
58-
// boxpredictor3: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256/Relu6
59-
// boxpredictor4: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_4_3x3_s2_256/Relu6
60-
// boxpredictor5: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6
61-
62-
return out
59+
return {
60+
boxPredictions,
61+
classPredictions
62+
}
6363

6464
})
6565
}

src/faceDetectionNet/mobileNetV1.ts

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -26,22 +26,33 @@ function depthwiseConvLayer(
2626
})
2727
}
2828

29-
30-
3129
function getStridesForLayerIdx(layerIdx: number): [number, number] {
3230
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
3331
}
3432

3533
export function mobileNetV1(x: tf.Tensor4D, params: FaceDetectionNet.MobileNetV1.Params) {
3634
return tf.tidy(() => {
3735

36+
let conv11 = null
3837
let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2])
3938
params.conv_pair_params.forEach((param, i) => {
40-
const depthwiseConvStrides = getStridesForLayerIdx(i + 1)
39+
const layerIdx = i + 1
40+
const depthwiseConvStrides = getStridesForLayerIdx(layerIdx)
4141
out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides)
4242
out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1])
43+
if (layerIdx === 11) {
44+
conv11 = out
45+
}
4346
})
44-
return out
47+
48+
if (conv11 === null) {
49+
throw new Error('mobileNetV1 - output of conv layer 11 is null')
50+
}
51+
52+
return {
53+
out,
54+
conv11: conv11 as any
55+
}
4556

4657
})
4758
}

src/faceDetectionNet/predictionLayer.ts

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4,19 +4,19 @@ import { boxPredictionLayer } from './boxPredictionLayer';
44
import { pointwiseConvLayer } from './pointwiseConvLayer';
55
import { FaceDetectionNet } from './types';
66

7-
export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
7+
export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
88
return tf.tidy(() => {
99

1010
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
11-
const conv1 = pointwiseConvLayer(x, params.conv_1_params, [2, 2])
12-
const conv2 = pointwiseConvLayer(x, params.conv_2_params, [1, 1])
13-
const conv3 = pointwiseConvLayer(x, params.conv_3_params, [2, 2])
14-
const conv4 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
15-
const conv5 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
16-
const conv6 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
17-
const conv7 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
11+
const conv1 = pointwiseConvLayer(conv0, params.conv_1_params, [2, 2])
12+
const conv2 = pointwiseConvLayer(conv1, params.conv_2_params, [1, 1])
13+
const conv3 = pointwiseConvLayer(conv2, params.conv_3_params, [2, 2])
14+
const conv4 = pointwiseConvLayer(conv3, params.conv_4_params, [1, 1])
15+
const conv5 = pointwiseConvLayer(conv4, params.conv_5_params, [2, 2])
16+
const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1])
17+
const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2])
1818

19-
const boxPrediction0 = boxPredictionLayer(x, params.box_predictor_0_params, 3072)
19+
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params, 3072)
2020
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
2121
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
2222
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
@@ -30,7 +30,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
3030
boxPrediction3.boxPredictionEncoding,
3131
boxPrediction4.boxPredictionEncoding,
3232
boxPrediction5.boxPredictionEncoding
33-
])
33+
], 1)
3434

3535
const classPredictions = tf.concat([
3636
boxPrediction0.classPrediction,
@@ -39,7 +39,7 @@ export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.Predict
3939
boxPrediction3.classPrediction,
4040
boxPrediction4.classPrediction,
4141
boxPrediction5.classPrediction
42-
])
42+
], 1)
4343

4444
return {
4545
boxPredictions,

src/faceDetectionNet/types.ts

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@ export namespace FaceDetectionNet {
3535
}
3636

3737
export type BoxPredictionParams = {
38-
class_predictor_params: ConvWithBiasParams
3938
box_encoding_predictor_params: ConvWithBiasParams
39+
class_predictor_params: ConvWithBiasParams
4040
}
4141

4242
export type PredictionParams = {
@@ -46,6 +46,8 @@ export namespace FaceDetectionNet {
4646
conv_3_params: PointwiseConvParams
4747
conv_4_params: PointwiseConvParams
4848
conv_5_params: PointwiseConvParams
49+
conv_6_params: PointwiseConvParams
50+
conv_7_params: PointwiseConvParams
4951
box_predictor_0_params: BoxPredictionParams
5052
box_predictor_1_params: BoxPredictionParams
5153
box_predictor_2_params: BoxPredictionParams
@@ -55,7 +57,8 @@ export namespace FaceDetectionNet {
5557
}
5658

5759
export type NetParams = {
58-
mobilenetv1_params: MobileNetV1.Params
60+
mobilenetv1_params: MobileNetV1.Params,
61+
prediction_layer_params: PredictionParams
5962
}
6063

6164
}

0 commit comments

Comments
 (0)