Skip to content

Commit 62163b6

Browse files
init faceLandmarkNet
1 parent c378f93 commit 62163b6

File tree

13 files changed

+199
-49
lines changed

13 files changed

+199
-49
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ConvParams, ExtractWeightsFunction } from './types';
4+
5+
export function extractConvParamsFactory(extractWeights: ExtractWeightsFunction) {
6+
return function (
7+
channelsIn: number,
8+
channelsOut: number,
9+
filterSize: number
10+
): ConvParams {
11+
const filters = tf.tensor4d(
12+
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
13+
[filterSize, filterSize, channelsIn, channelsOut]
14+
)
15+
const bias = tf.tensor1d(extractWeights(channelsOut))
16+
17+
return {
18+
filters,
19+
bias
20+
}
21+
}
22+
}

src/commons/extractWeightsFactory.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
export function extractWeightsFactory(weights: Float32Array) {
2+
let remainingWeights = weights
3+
4+
function extractWeights(numWeights: number): Float32Array {
5+
const ret = remainingWeights.slice(0, numWeights)
6+
remainingWeights = remainingWeights.slice(numWeights)
7+
return ret
8+
}
9+
10+
function getRemainingWeights(): Float32Array {
11+
return remainingWeights
12+
}
13+
14+
return {
15+
extractWeights,
16+
getRemainingWeights
17+
}
18+
}

src/commons/types.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
export type ConvParams = {
4+
filters: tf.Tensor4D
5+
bias: tf.Tensor1D
6+
}
7+
8+
export type ExtractWeightsFunction = (numWeights: number) => Float32Array

src/faceDetectionNet/boxPredictionLayer.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { ConvParams } from '../commons/types';
34
import { FaceDetectionNet } from './types';
45

56
function convWithBias(
67
x: tf.Tensor4D,
7-
params: FaceDetectionNet.ConvWithBiasParams
8+
params: ConvParams
89
) {
910
return tf.tidy(() =>
1011
tf.add(

src/faceDetectionNet/extractParams.ts

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
4+
import { ConvParams } from '../commons/types';
35
import { FaceDetectionNet } from './types';
46

57
function extractorsFactory(extractWeights: (numWeights: number) => Float32Array) {
@@ -20,11 +22,11 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
2022
}
2123
}
2224

23-
function extractConvWithBiasParams(
25+
function extractConvParams(
2426
channelsIn: number,
2527
channelsOut: number,
2628
filterSize: number
27-
): FaceDetectionNet.ConvWithBiasParams {
29+
): ConvParams {
2830
const filters = tf.tensor4d(
2931
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
3032
[filterSize, filterSize, channelsIn, channelsOut]
@@ -45,7 +47,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
4547
const {
4648
filters,
4749
bias
48-
} = extractConvWithBiasParams(channelsIn, channelsOut, filterSize)
50+
} = extractConvParams(channelsIn, channelsOut, filterSize)
4951

5052
return {
5153
filters,
@@ -104,18 +106,18 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
104106
const conv_6_params = extractPointwiseConvParams(256, 64, 1)
105107
const conv_7_params = extractPointwiseConvParams(64, 128, 3)
106108

107-
const box_encoding_0_predictor_params = extractConvWithBiasParams(512, 12, 1)
108-
const class_predictor_0_params = extractConvWithBiasParams(512, 9, 1)
109-
const box_encoding_1_predictor_params = extractConvWithBiasParams(1024, 24, 1)
110-
const class_predictor_1_params = extractConvWithBiasParams(1024, 18, 1)
111-
const box_encoding_2_predictor_params = extractConvWithBiasParams(512, 24, 1)
112-
const class_predictor_2_params = extractConvWithBiasParams(512, 18, 1)
113-
const box_encoding_3_predictor_params = extractConvWithBiasParams(256, 24, 1)
114-
const class_predictor_3_params = extractConvWithBiasParams(256, 18, 1)
115-
const box_encoding_4_predictor_params = extractConvWithBiasParams(256, 24, 1)
116-
const class_predictor_4_params = extractConvWithBiasParams(256, 18, 1)
117-
const box_encoding_5_predictor_params = extractConvWithBiasParams(128, 24, 1)
118-
const class_predictor_5_params = extractConvWithBiasParams(128, 18, 1)
109+
const box_encoding_0_predictor_params = extractConvParams(512, 12, 1)
110+
const class_predictor_0_params = extractConvParams(512, 9, 1)
111+
const box_encoding_1_predictor_params = extractConvParams(1024, 24, 1)
112+
const class_predictor_1_params = extractConvParams(1024, 18, 1)
113+
const box_encoding_2_predictor_params = extractConvParams(512, 24, 1)
114+
const class_predictor_2_params = extractConvParams(512, 18, 1)
115+
const box_encoding_3_predictor_params = extractConvParams(256, 24, 1)
116+
const class_predictor_3_params = extractConvParams(256, 18, 1)
117+
const box_encoding_4_predictor_params = extractConvParams(256, 24, 1)
118+
const class_predictor_4_params = extractConvParams(256, 18, 1)
119+
const box_encoding_5_predictor_params = extractConvParams(128, 24, 1)
120+
const class_predictor_5_params = extractConvParams(128, 18, 1)
119121

120122
const box_predictor_0_params = {
121123
box_encoding_predictor_params: box_encoding_0_predictor_params,
@@ -169,11 +171,10 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
169171
}
170172

171173
export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams {
172-
const extractWeights = (numWeights: number): Float32Array => {
173-
const ret = weights.slice(0, numWeights)
174-
weights = weights.slice(numWeights)
175-
return ret
176-
}
174+
const {
175+
extractWeights,
176+
getRemainingWeights
177+
} = extractWeightsFactory(weights)
177178

178179
const {
179180
extractMobilenetV1Params,
@@ -190,8 +191,8 @@ export function extractParams(weights: Float32Array): FaceDetectionNet.NetParams
190191
extra_dim
191192
}
192193

193-
if (weights.length !== 0) {
194-
throw new Error(`weights remaing after extract: ${weights.length}`)
194+
if (getRemainingWeights().length !== 0) {
195+
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
195196
}
196197

197198
return {

src/faceDetectionNet/types.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { ConvParams } from '../commons/types';
4+
35
export namespace FaceDetectionNet {
46

57
export type PointwiseConvParams = {
@@ -29,14 +31,9 @@ export namespace FaceDetectionNet {
2931

3032
}
3133

32-
export type ConvWithBiasParams = {
33-
filters: tf.Tensor4D
34-
bias: tf.Tensor1D
35-
}
36-
3734
export type BoxPredictionParams = {
38-
box_encoding_predictor_params: ConvWithBiasParams
39-
class_predictor_params: ConvWithBiasParams
35+
box_encoding_predictor_params: ConvParams
36+
class_predictor_params: ConvParams
4037
}
4138

4239
export type PredictionLayerParams = {

src/faceLandmarkNet/extractParams.ts

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
2+
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
3+
import { FaceLandmarkNet } from './types';
4+
import * as tf from '@tensorflow/tfjs-core';
5+
6+
export function extractParams(weights: Float32Array): FaceLandmarkNet.NetParams {
7+
const {
8+
extractWeights,
9+
getRemainingWeights
10+
} = extractWeightsFactory(weights)
11+
12+
const extractConvParams = extractConvParamsFactory(extractWeights)
13+
14+
function extractFcParams(channelsIn: number, channelsOut: number,): FaceLandmarkNet.FCParams {
15+
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
16+
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
17+
return {
18+
weights: fc_weights,
19+
bias: fc_bias
20+
}
21+
}
22+
23+
if (getRemainingWeights().length !== 0) {
24+
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)
25+
}
26+
27+
return {
28+
conv0_params: extractConvParams(3, 32, 3),
29+
conv1_params: extractConvParams(32, 64, 3),
30+
conv2_params: extractConvParams(64, 64, 3),
31+
conv3_params: extractConvParams(64, 64, 3),
32+
conv4_params: extractConvParams(64, 64, 3),
33+
conv5_params: extractConvParams(64, 128, 3),
34+
conv6_params: extractConvParams(128, 128, 3),
35+
conv7_params: extractConvParams(128, 256, 3),
36+
fc0_params: extractFcParams(6400, 1024),
37+
fc1_params:extractFcParams(1024, 136)
38+
}
39+
}

src/faceLandmarkNet/index.ts

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { getImageTensor } from '../getImageTensor';
4+
import { NetInput } from '../NetInput';
5+
import { padToSquare } from '../padToSquare';
6+
import { TNetInput } from '../types';
7+
import { extractParams } from './extractParams';
8+
9+
export function faceLandmarkNet(weights: Float32Array) {
10+
const params = extractParams(weights)
11+
12+
function forward(input: tf.Tensor | NetInput | TNetInput) {
13+
return tf.tidy(() => {
14+
15+
let x = padToSquare(getImageTensor(input), true)
16+
// work with 128 x 128 sized face images
17+
if (x.shape[1] !== 128 || x.shape[2] !== 128) {
18+
x = tf.image.resizeBilinear(x, [128, 128])
19+
}
20+
21+
// pool 1
22+
tf.maxPool(x, [2, 2], [2, 2], 'valid')
23+
// pool 2
24+
tf.maxPool(x, [2, 2], [2, 2], 'valid')
25+
// pool 3
26+
tf.maxPool(x, [2, 2], [2, 2], 'valid')
27+
// pool 4
28+
tf.maxPool(x, [2, 2], [1, 1], 'valid')
29+
// TODO
30+
31+
return x
32+
})
33+
}
34+
35+
return {
36+
forward
37+
}
38+
}

src/faceLandmarkNet/types.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ConvParams } from '../commons/types';
4+
5+
export namespace FaceLandmarkNet {
6+
7+
export type FCParams = {
8+
weights: tf.Tensor2D
9+
bias: tf.Tensor1D
10+
}
11+
12+
export type NetParams = {
13+
conv0_params: ConvParams
14+
conv1_params: ConvParams
15+
conv2_params: ConvParams
16+
conv3_params: ConvParams
17+
conv4_params: ConvParams
18+
conv5_params: ConvParams
19+
conv6_params: ConvParams
20+
conv7_params: ConvParams
21+
fc0_params: FCParams
22+
fc1_params: FCParams
23+
}
24+
25+
}

src/faceRecognitionNet/convLayer.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ function convLayer(
1111
withRelu: boolean,
1212
padding: 'valid' | 'same' = 'same'
1313
): tf.Tensor4D {
14-
const { filters, biases } = params.conv
14+
const { filters, bias } = params.conv
1515

1616
let out = tf.conv2d(x, filters, [stride, stride], padding)
17-
out = tf.add(out, biases)
17+
out = tf.add(out, bias)
1818
out = scale(out, params.scale)
1919
return withRelu ? tf.relu(out) : out
2020
}

0 commit comments

Comments
 (0)