Skip to content

Commit 2efff4f

Browse files
implemented mobilenetv1 for face detector
1 parent 6f8221a commit 2efff4f

File tree

7 files changed

+117
-30
lines changed

7 files changed

+117
-30
lines changed

src/faceDetectionNet/extractParams.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,27 @@ import { FaceDetectionNet } from './types';
55
function mobilenetV1WeightsExtractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
66

77
function extractDepthwiseConvParams(numChannels: number): FaceDetectionNet.MobileNetV1.DepthwiseConvParams {
8-
const weights = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1])
9-
const batch_norm_gamma = tf.tensor1d(extractWeights(numChannels))
10-
const batch_norm_beta = tf.tensor1d(extractWeights(numChannels))
8+
const filters = tf.tensor4d(extractWeights(3 * 3 * numChannels), [3, 3, numChannels, 1])
9+
const batch_norm_scale = tf.tensor1d(extractWeights(numChannels))
10+
const batch_norm_offset = tf.tensor1d(extractWeights(numChannels))
1111
const batch_norm_mean = tf.tensor1d(extractWeights(numChannels))
1212
const batch_norm_variance = tf.tensor1d(extractWeights(numChannels))
1313

1414
return {
15-
weights,
16-
batch_norm_gamma,
17-
batch_norm_beta,
15+
filters,
16+
batch_norm_scale,
17+
batch_norm_offset,
1818
batch_norm_mean,
1919
batch_norm_variance
2020
}
2121
}
2222

2323
function extractPointwiseConvParams(channelsIn: number, channelsOut: number): FaceDetectionNet.MobileNetV1.PointwiseConvParams {
24-
const weights = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
24+
const filters = tf.tensor4d(extractWeights(channelsIn * channelsOut), [1, 1, channelsIn, channelsOut])
2525
const batch_norm_offset = tf.tensor1d(extractWeights(channelsOut))
2626

2727
return {
28-
weights,
28+
filters,
2929
batch_norm_offset
3030
}
3131
}
@@ -59,7 +59,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
5959
function extractMobilenetV1Params(): FaceDetectionNet.MobileNetV1.Params {
6060

6161
const conv_0_params = {
62-
weights: tf.tensor4d(extractWeights(3 * 3 * 3 * 32), [3, 3, 3, 32]),
62+
filters: tf.tensor4d(extractWeights(3 * 3 * 3 * 32), [3, 3, 3, 32]),
6363
batch_norm_offset: tf.tensor1d(extractWeights(32))
6464

6565
}

src/faceDetectionNet/index.ts

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,58 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { resizeLayer } from './resizeLayer';
3+
import { isFloat } from '../utils';
44
import { extractParams } from './extractParams';
55
import { mobileNetV1 } from './mobileNetV1';
6+
import { resizeLayer } from './resizeLayer';
7+
8+
function fromData(input: number[]): tf.Tensor4D {
9+
const pxPerChannel = input.length / 3
10+
const dim = Math.sqrt(pxPerChannel)
11+
12+
if (isFloat(dim)) {
13+
throw new Error(`invalid input size: ${dim}x${dim}x3 (array length: ${input.length})`)
14+
}
15+
16+
return tf.tensor4d(input as number[], [1, 580, 580, 3])
17+
}
18+
19+
function fromImageData(input: ImageData[]) {
20+
const idx = input.findIndex(data => !(data instanceof ImageData))
21+
if (idx !== -1) {
22+
throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
23+
}
24+
25+
const imgTensors = input
26+
.map(data => tf.fromPixels(data))
27+
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
28+
29+
return tf.cast(tf.concat(imgTensors, 0), 'float32')
30+
}
631

732
export function faceDetectionNet(weights: Float32Array) {
833
const params = extractParams(weights)
934

10-
async function forward(input: ImageData|ImageData[]) {
35+
async function forward(input: ImageData|ImageData[]|number[]) {
36+
return tf.tidy(() => {
1137

12-
const imgTensors = (input instanceof ImageData ? [input] : input)
13-
.map(data => tf.fromPixels(data))
14-
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
38+
const imgDataArray = input instanceof ImageData
39+
? [input]
40+
: (
41+
input[0] instanceof ImageData
42+
? input as ImageData[]
43+
: null
44+
)
1545

16-
const imgTensor = tf.cast(tf.concat(imgTensors, 0), 'float32')
46+
const imgTensor = imgDataArray !== null
47+
? fromImageData(imgDataArray)
48+
: fromData(input as number[])
1749

18-
let out = resizeLayer(imgTensor) as tf.Tensor4D
50+
let out = resizeLayer(imgTensor) as tf.Tensor4D
51+
out = mobileNetV1(out, params.mobilenetv1_params)
1952

20-
out = mobileNetV1(out, params.mobilenetv1_params)
53+
return out
2154

22-
return out
55+
})
2356
}
2457

2558
return {

src/faceDetectionNet/mobileNetV1.ts

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,57 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { FaceDetectionNet } from './types';
44

5+
const epsilon = 0.0010000000474974513
6+
7+
function depthwiseConvLayer(
8+
x: tf.Tensor4D,
9+
params: FaceDetectionNet.MobileNetV1.DepthwiseConvParams,
10+
strides: [number, number]
11+
) {
12+
return tf.tidy(() => {
13+
14+
let out = tf.depthwiseConv2d(x, params.filters, strides, 'same')
15+
out = tf.batchNormalization<tf.Rank.R4>(
16+
out,
17+
params.batch_norm_mean,
18+
params.batch_norm_variance,
19+
epsilon,
20+
params.batch_norm_scale,
21+
params.batch_norm_offset
22+
)
23+
return tf.relu(out)
24+
25+
})
26+
}
27+
28+
function pointwiseConvLayer(
29+
x: tf.Tensor4D,
30+
params: FaceDetectionNet.MobileNetV1.PointwiseConvParams,
31+
strides: [number, number]
32+
) {
33+
return tf.tidy(() => {
34+
35+
let out = tf.conv2d(x, params.filters, strides, 'same')
36+
out = tf.add(out, params.batch_norm_offset)
37+
return tf.relu(out)
38+
39+
})
40+
}
41+
42+
function getStridesForLayerIdx(layerIdx: number): [number, number] {
43+
return [2, 4, 6, 12].some(idx => idx === layerIdx) ? [2, 2] : [1, 1]
44+
}
45+
546
export function mobileNetV1(x: tf.Tensor4D, params: FaceDetectionNet.MobileNetV1.Params) {
6-
return x
47+
return tf.tidy(() => {
48+
49+
let out = pointwiseConvLayer(x, params.conv_0_params, [2, 2])
50+
params.conv_pair_params.forEach((param, i) => {
51+
const depthwiseConvStrides = getStridesForLayerIdx(i + 1)
52+
out = depthwiseConvLayer(out, param.depthwise_conv_params, depthwiseConvStrides)
53+
out = pointwiseConvLayer(out, param.pointwise_conv_params, [1, 1])
54+
})
55+
return out
56+
57+
})
758
}

src/faceDetectionNet/resizeLayer.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
// TODO: hardcoded params
43
const resizedImageSize = [512, 512] as [number, number]
54
const weight = tf.scalar(0.007843137718737125)
65
const bias = tf.scalar(1)
76

87
export function resizeLayer(x: tf.Tensor4D) {
9-
const resized = tf.image.resizeBilinear(x, resizedImageSize, false)
10-
return tf.sub(tf.mul(resized, weight), bias)
8+
return tf.tidy(() => {
9+
10+
const resized = tf.image.resizeBilinear(x, resizedImageSize, false)
11+
return tf.sub(tf.mul(resized, weight), bias)
12+
13+
})
1114
}

src/faceDetectionNet/types.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ export namespace FaceDetectionNet {
55
export namespace MobileNetV1 {
66

77
export type DepthwiseConvParams = {
8-
weights: tf.Tensor4D // [3, 3, ch, 1]
9-
batch_norm_gamma: tf.Tensor1D
10-
batch_norm_beta: tf.Tensor1D
8+
filters: tf.Tensor4D
9+
batch_norm_scale: tf.Tensor1D
10+
batch_norm_offset: tf.Tensor1D
1111
batch_norm_mean: tf.Tensor1D
1212
batch_norm_variance: tf.Tensor1D
1313
}
1414

1515
export type PointwiseConvParams = {
16-
weights: tf.Tensor4D // [1, 1, ch_in, ch_out]
16+
filters: tf.Tensor4D
1717
batch_norm_offset: tf.Tensor1D
1818
}
1919

src/faceRecognitionNet/extractParams.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { isFloat } from '../utils';
34
import { FaceRecognitionNet } from './types';
45

5-
function isFloat(num: number) {
6-
return num % 1 !== 0
7-
}
8-
96
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
107

118
function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {

src/utils.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export function isFloat(num: number) {
2+
return num % 1 !== 0
3+
}

0 commit comments

Comments
 (0)