Skip to content

Commit 8ad95e3

Browse files
init box prediction
1 parent 2efff4f commit 8ad95e3

File tree

7 files changed

+158
-35
lines changed

7 files changed

+158
-35
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { FaceDetectionNet } from './types';
4+
5+
function boxEncodingPredictionLayer(
6+
x: tf.Tensor4D,
7+
params: FaceDetectionNet.ConvWithBiasParams
8+
) {
9+
return tf.tidy(() => {
10+
11+
// TODO
12+
return x
13+
14+
})
15+
}
16+
17+
function classPredictionLayer(
18+
x: tf.Tensor4D,
19+
params: FaceDetectionNet.ConvWithBiasParams
20+
) {
21+
return tf.tidy(() => {
22+
23+
// TODO
24+
return x
25+
26+
})
27+
}
28+
29+
export function boxPredictionLayer(
30+
x: tf.Tensor4D,
31+
params: FaceDetectionNet.BoxPredictionParams,
32+
size: number
33+
) {
34+
return tf.tidy(() => {
35+
36+
const boxPredictionEncoding = tf.reshape(
37+
boxEncodingPredictionLayer(x, params.box_encoding_predictor_params),
38+
[x.shape[0], size, 1, 4]
39+
)
40+
const classPrediction = tf.reshape(
41+
classPredictionLayer(x, params.class_predictor_params),
42+
[x.shape[0], size, 3]
43+
)
44+
45+
return {
46+
boxPredictionEncoding,
47+
classPrediction
48+
}
49+
})
50+
}

src/faceDetectionNet/extractParams.ts

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { FaceDetectionNet } from './types';
44

5-
function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
5+
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
66

77
function extractDepthwiseConvParams(numChannels: number): FaceDetectionNet.MobileNetV1.DepthwiseConvParams {
88
const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1])
@@ -20,7 +20,7 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number
2020
}
2121
}
2222

23-
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.PointwiseConvParams {
23+
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.PointwiseConvParams {
2424
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
2525
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut))
2626

@@ -40,22 +40,6 @@ function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number
4040
}
4141
}
4242

43-
return {
44-
extractPointwiseConvParams,
45-
extractConvPairParams
46-
}
47-
48-
}
49-
50-
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
51-
52-
const {
53-
extractPointwiseConvParams,
54-
extractConvPairParams
55-
} = mobilenetV1WeightsExtractorsFactory(extractWeights)
56-
57-
58-
5943
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
6044

6145
const conv_0_params = {

src/faceDetectionNet/index.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ export function faceDetectionNet(weights: Float32Array) {
5050
let out = resizeLayer(imgTensor) as tf.Tensor4D
5151
out = mobileNetV1(out, params.mobilenetv1_params)
5252

53+
54+
55+
// boxpredictor0: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
56+
// boxpredictor1: FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6
57+
// boxpredictor2: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_512/Relu6
58+
// boxpredictor3: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256/Relu6
59+
// boxpredictor4: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_4_3x3_s2_256/Relu6
60+
// boxpredictor5: FeatureExtractor/MobilenetV1/Conv2d_13_pointwise_2_Conv2d_5_3x3_s2_128/Relu6
61+
5362
return out
5463

5564
})

src/faceDetectionNet/mobileNetV1.ts

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { pointwiseConvLayer } from './pointwiseConvLayer';
34
import { FaceDetectionNet } from './types';
45

56
const epsilon = 0.0010000000474974513
@@ -25,19 +26,7 @@ function depthwiseConvLayer(
2526
})
2627
}
2728

28-
function pointwiseConvLayer(
29-
x: tf.Tensor4D,
30-
params: FaceDetectionNet.MobileNetV1.PointwiseConvParams,
31-
strides: [number, number]
32-
) {
33-
return tf.tidy(() => {
3429

35-
let out = tf.conv2d(x, params.filters, strides, 'same')
36-
out = tf.add(out, params.batch_norm_offset)
37-
return tf.relu(out)
38-
39-
})
40-
}
4130

4231
function getStridesForLayerIdx(layerIdx: number): [number, number] {
4332
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { FaceDetectionNet } from './types';
4+
5+
export function pointwiseConvLayer(
6+
x: tf.Tensor4D,
7+
params: FaceDetectionNet.PointwiseConvParams,
8+
strides: [number, number]
9+
) {
10+
return tf.tidy(() => {
11+
12+
let out = tf.conv2d(x, params.filters, strides, 'same')
13+
out = tf.add(out, params.batch_norm_offset)
14+
return tf.relu(out)
15+
16+
})
17+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { boxPredictionLayer } from './boxPredictionLayer';
4+
import { pointwiseConvLayer } from './pointwiseConvLayer';
5+
import { FaceDetectionNet } from './types';
6+
7+
export function predictionLayer(x: tf.Tensor4D, params: FaceDetectionNet.PredictionParams) {
8+
return tf.tidy(() => {
9+
10+
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
11+
const conv1 = pointwiseConvLayer(x, params.conv_1_params, [2, 2])
12+
const conv2 = pointwiseConvLayer(x, params.conv_2_params, [1, 1])
13+
const conv3 = pointwiseConvLayer(x, params.conv_3_params, [2, 2])
14+
const conv4 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
15+
const conv5 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
16+
const conv6 = pointwiseConvLayer(x, params.conv_4_params, [1, 1])
17+
const conv7 = pointwiseConvLayer(x, params.conv_5_params, [2, 2])
18+
19+
const boxPrediction0 = boxPredictionLayer(x, params.box_predictor_0_params, 3072)
20+
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
21+
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
22+
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
23+
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params, 24)
24+
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params, 6)
25+
26+
const boxPredictions = tf.concat([
27+
boxPrediction0.boxPredictionEncoding,
28+
boxPrediction1.boxPredictionEncoding,
29+
boxPrediction2.boxPredictionEncoding,
30+
boxPrediction3.boxPredictionEncoding,
31+
boxPrediction4.boxPredictionEncoding,
32+
boxPrediction5.boxPredictionEncoding
33+
])
34+
35+
const classPredictions = tf.concat([
36+
boxPrediction0.classPrediction,
37+
boxPrediction1.classPrediction,
38+
boxPrediction2.classPrediction,
39+
boxPrediction3.classPrediction,
40+
boxPrediction4.classPrediction,
41+
boxPrediction5.classPrediction
42+
])
43+
44+
return {
45+
boxPredictions,
46+
classPredictions
47+
}
48+
})
49+
}

src/faceDetectionNet/types.ts

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
export namespace FaceDetectionNet {
44

5+
export type PointwiseConvParams = {
6+
filters: tf.Tensor4D
7+
batch_norm_offset: tf.Tensor1D
8+
}
9+
510
export namespace MobileNetV1 {
611

712
export type DepthwiseConvParams = {
@@ -12,11 +17,6 @@ export namespace FaceDetectionNet {
1217
batch_norm_variance: tf.Tensor1D
1318
}
1419

15-
export type PointwiseConvParams = {
16-
filters: tf.Tensor4D
17-
batch_norm_offset: tf.Tensor1D
18-
}
19-
2020
export type ConvPairParams = {
2121
depthwise_conv_params: DepthwiseConvParams
2222
pointwise_conv_params: PointwiseConvParams
@@ -29,6 +29,31 @@ export namespace FaceDetectionNet {
2929

3030
}
3131

32+
export type ConvWithBiasParams = {
33+
filters: tf.Tensor4D
34+
bias: tf.Tensor1D
35+
}
36+
37+
export type BoxPredictionParams = {
38+
class_predictor_params: ConvWithBiasParams
39+
box_encoding_predictor_params: ConvWithBiasParams
40+
}
41+
42+
export type PredictionParams = {
43+
conv_0_params: PointwiseConvParams
44+
conv_1_params: PointwiseConvParams
45+
conv_2_params: PointwiseConvParams
46+
conv_3_params: PointwiseConvParams
47+
conv_4_params: PointwiseConvParams
48+
conv_5_params: PointwiseConvParams
49+
box_predictor_0_params: BoxPredictionParams
50+
box_predictor_1_params: BoxPredictionParams
51+
box_predictor_2_params: BoxPredictionParams
52+
box_predictor_3_params: BoxPredictionParams
53+
box_predictor_4_params: BoxPredictionParams
54+
box_predictor_5_params: BoxPredictionParams
55+
}
56+
3257
export type NetParams = {
3358
mobilenetv1_params: MobileNetV1.Params
3459
}

0 commit comments

Comments
 (0)