Skip to content

Commit ae742d9

Browse files
init mtcnn weights + weight loading + PNet
1 parent c3cfe31 commit ae742d9

17 files changed

+409
-53
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ConvParams, ExtractWeightsFunction, ParamMapping } from './types';
4+
5+
export function extractConvParamsFactory(
6+
extractWeights: ExtractWeightsFunction,
7+
paramMappings: ParamMapping[]
8+
) {
9+
10+
return function(
11+
channelsIn: number,
12+
channelsOut: number,
13+
filterSize: number,
14+
mappedPrefix: string
15+
): ConvParams {
16+
17+
const filters = tf.tensor4d(
18+
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
19+
[filterSize, filterSize, channelsIn, channelsOut]
20+
)
21+
const bias = tf.tensor1d(extractWeights(channelsOut))
22+
23+
paramMappings.push(
24+
{ paramPath: `${mappedPrefix}/filters` },
25+
{ paramPath: `${mappedPrefix}/bias` }
26+
)
27+
28+
return { filters, bias }
29+
}
30+
31+
}

src/commons/extractFCParamsFactory.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { ExtractWeightsFunction, FCParams, ParamMapping } from './types';
4+
5+
export function extractFCParamsFactory(
6+
extractWeights: ExtractWeightsFunction,
7+
paramMappings: ParamMapping[]
8+
) {
9+
10+
return function(
11+
channelsIn: number,
12+
channelsOut: number,
13+
mappedPrefix: string
14+
): FCParams {
15+
16+
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
17+
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
18+
19+
paramMappings.push(
20+
{ paramPath: `${mappedPrefix}/weights` },
21+
{ paramPath: `${mappedPrefix}/bias` }
22+
)
23+
24+
return {
25+
weights: fc_weights,
26+
bias: fc_bias
27+
}
28+
}
29+
30+
}

src/commons/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ export type ConvParams = {
55
bias: tf.Tensor1D
66
}
77

8+
export type FCParams = {
9+
weights: tf.Tensor2D
10+
bias: tf.Tensor1D
11+
}
12+
813
export type ExtractWeightsFunction = (numWeights: number) => Float32Array
914

1015
export type BatchReshapeInfo = {

src/faceLandmarkNet/extractParams.ts

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import * as tf from '@tensorflow/tfjs-core';
2-
1+
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
2+
import { extractFCParamsFactory } from '../commons/extractFCParamsFactory';
33
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
4-
import { ConvParams, ParamMapping } from '../commons/types';
5-
import { FCParams, NetParams } from './types';
4+
import { ParamMapping } from '../commons/types';
5+
import { NetParams } from './types';
66

77
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
88

@@ -13,42 +13,8 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
1313
getRemainingWeights
1414
} = extractWeightsFactory(weights)
1515

16-
function extractConvParams(
17-
channelsIn: number,
18-
channelsOut: number,
19-
filterSize: number,
20-
mappedPrefix: string
21-
): ConvParams {
22-
23-
const filters = tf.tensor4d(
24-
extractWeights(channelsIn * channelsOut * filterSize * filterSize),
25-
[filterSize, filterSize, channelsIn, channelsOut]
26-
)
27-
const bias = tf.tensor1d(extractWeights(channelsOut))
28-
29-
paramMappings.push(
30-
{ paramPath: `${mappedPrefix}/filters` },
31-
{ paramPath: `${mappedPrefix}/bias` }
32-
)
33-
34-
return { filters, bias }
35-
}
36-
37-
function extractFcParams(channelsIn: number, channelsOut: number, mappedPrefix: string): FCParams {
38-
39-
const fc_weights = tf.tensor2d(extractWeights(channelsIn * channelsOut), [channelsIn, channelsOut])
40-
const fc_bias = tf.tensor1d(extractWeights(channelsOut))
41-
42-
paramMappings.push(
43-
{ paramPath: `${mappedPrefix}/weights` },
44-
{ paramPath: `${mappedPrefix}/bias` }
45-
)
46-
47-
return {
48-
weights: fc_weights,
49-
bias: fc_bias
50-
}
51-
}
16+
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
17+
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
5218

5319
const conv0 = extractConvParams(3, 32, 3, 'conv0')
5420
const conv1 = extractConvParams(32, 64, 3, 'conv1')
@@ -58,8 +24,8 @@ export function extractParams(weights: Float32Array): { params: NetParams, param
5824
const conv5 = extractConvParams(64, 128, 3, 'conv5')
5925
const conv6 = extractConvParams(128, 128, 3, 'conv6')
6026
const conv7 = extractConvParams(128, 256, 3, 'conv7')
61-
const fc0 = extractFcParams(6400, 1024, 'fc0')
62-
const fc1 = extractFcParams(1024, 136, 'fc1')
27+
const fc0 = extractFCParams(6400, 1024, 'fc0')
28+
const fc1 = extractFCParams(1024, 136, 'fc1')
6329

6430
if (getRemainingWeights().length !== 0) {
6531
throw new Error(`weights remaing after extract: ${getRemainingWeights().length}`)

src/faceLandmarkNet/fullyConnectedLayer.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { FCParams } from './types';
3+
import { FCParams } from '../commons/types';
44

55
export function fullyConnectedLayer(
66
x: tf.Tensor2D,

src/faceLandmarkNet/loadQuantizedParams.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import * as tf from '@tensorflow/tfjs-core';
33
import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
44
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
55
import { loadWeightMap } from '../commons/loadWeightMap';
6-
import { ConvParams, ParamMapping } from '../commons/types';
7-
import { FCParams, NetParams } from './types';
6+
import { ConvParams, FCParams, ParamMapping } from '../commons/types';
7+
import { NetParams } from './types';
88

99
const DEFAULT_MODEL_NAME = 'face_landmark_68_model'
1010

src/faceLandmarkNet/types.ts

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
import * as tf from '@tensorflow/tfjs-core';
2-
3-
import { ConvParams } from '../commons/types';
4-
5-
export type FCParams = {
6-
weights: tf.Tensor2D
7-
bias: tf.Tensor1D
8-
}
1+
import { ConvParams, FCParams } from '../commons/types';
92

103
export type NetParams = {
114
conv0: ConvParams

src/mtcnn/Mtcnn.ts

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { NeuralNetwork } from '../commons/NeuralNetwork';
4+
import { NetInput } from '../NetInput';
5+
import { toNetInput } from '../toNetInput';
6+
import { TNetInput } from '../types';
7+
import { extractParams } from './extractParams';
8+
import { stage1 } from './stage1';
9+
import { NetParams } from './types';
10+
11+
export class Mtcnn extends NeuralNetwork<NetParams> {
12+
13+
constructor() {
14+
super('Mtcnn')
15+
}
16+
17+
public forwardInput(input: NetInput, minFaceSize: number = 20, scaleFactor: number = 0.709): tf.Tensor2D {
18+
19+
const { params } = this
20+
21+
if (!params) {
22+
throw new Error('Mtcnn - load model before inference')
23+
}
24+
25+
return tf.tidy(() => {
26+
const imgTensor = tf.expandDims(input.inputs[0]).toFloat() as tf.Tensor4D
27+
28+
function pyramidDown(minFaceSize: number, scaleFactor: number, dims: number[]): number[] {
29+
30+
const [height, width] = dims
31+
const m = 12 / minFaceSize
32+
33+
const scales = []
34+
35+
let minLayer = Math.min(height, width) * m
36+
let exp = 0
37+
while (minLayer >= 12) {
38+
scales.push(m * Math.pow(scaleFactor, exp))
39+
minLayer = minLayer * scaleFactor
40+
exp += 1
41+
}
42+
43+
return scales
44+
}
45+
46+
const scales = pyramidDown(minFaceSize, scaleFactor, imgTensor.shape)
47+
const out1 = stage1(imgTensor, scales, params.pnet)
48+
49+
return tf.tensor2d([0], [1, 1])
50+
})
51+
}
52+
53+
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
54+
return this.forwardInput(await toNetInput(input, true))
55+
}
56+
57+
protected extractParams(weights: Float32Array) {
58+
return extractParams(weights)
59+
}
60+
}

src/mtcnn/PNet.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { convLayer } from '../commons/convLayer';
4+
import { sharedLayer } from './sharedLayers';
5+
import { PNetParams } from './types';
6+
7+
export function PNet(x: tf.Tensor4D, params: PNetParams): { prob: tf.Tensor3D, convOut: tf.Tensor4D } {
8+
return tf.tidy(() => {
9+
10+
let out = sharedLayer(x, params)
11+
const conv = convLayer(out, params.conv4_1, 'valid')
12+
// TODO: tf.reduce_max <=> tf.max ?
13+
const logits = tf.sub(conv, tf.max(conv, 3))
14+
const prob = tf.softmax(logits, 3) as tf.Tensor3D
15+
const convOut = convLayer(out, params.conv4_2, 'valid')
16+
17+
return { prob, convOut }
18+
})
19+
}

src/mtcnn/extractParams.ts

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { extractConvParamsFactory } from '../commons/extractConvParamsFactory';
4+
import { extractFCParamsFactory } from '../commons/extractFCParamsFactory';
5+
import { extractWeightsFactory } from '../commons/extractWeightsFactory';
6+
import { ExtractWeightsFunction, ParamMapping } from '../commons/types';
7+
import { NetParams, PNetParams, RNetParams, SharedParams, ONetParams } from './types';
8+
9+
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {
10+
11+
const extractConvParams = extractConvParamsFactory(extractWeights, paramMappings)
12+
const extractFCParams = extractFCParamsFactory(extractWeights, paramMappings)
13+
14+
function extractPReluParams(size: number, paramPath: string): tf.Tensor1D {
15+
const alpha = tf.tensor1d(extractWeights(size))
16+
paramMappings.push({ paramPath })
17+
return alpha
18+
}
19+
20+
function extractSharedParams(numFilters: number[], mappedPrefix: string, isRnet: boolean = false): SharedParams {
21+
22+
const conv1 = extractConvParams(numFilters[0], numFilters[1], 3, `${mappedPrefix}/conv1`)
23+
const prelu1_alpha = extractPReluParams(numFilters[1], `${mappedPrefix}/prelu1_alpha`)
24+
const conv2 = extractConvParams(numFilters[1], numFilters[2], 3, `${mappedPrefix}/conv2`)
25+
const prelu2_alpha = extractPReluParams(numFilters[2], `${mappedPrefix}/prelu2_alpha`)
26+
const conv3 = extractConvParams(numFilters[2], numFilters[3], isRnet ? 2 : 3, `${mappedPrefix}/conv3`)
27+
const prelu3_alpha = extractPReluParams(numFilters[3], `${mappedPrefix}/prelu3_alpha`)
28+
29+
return { conv1, prelu1_alpha, conv2, prelu2_alpha, conv3, prelu3_alpha }
30+
}
31+
32+
function extractPNetParams(): PNetParams {
33+
34+
const sharedParams = extractSharedParams([3, 10, 16, 32], 'pnet')
35+
const conv4_1 = extractConvParams(32, 2, 1, 'pnet/conv4_1')
36+
const conv4_2 = extractConvParams(32, 4, 1, 'pnet/conv4_2')
37+
38+
return { ...sharedParams, conv4_1, conv4_2 }
39+
}
40+
41+
function extractRNetParams(): RNetParams {
42+
43+
const sharedParams = extractSharedParams([3, 28, 48, 64], 'rnet')
44+
const fc1 = extractFCParams(576, 128, 'rnet/fc1')
45+
const prelu4_alpha = extractPReluParams(128, 'rnet/prelu4_alpha')
46+
const fc2_1 = extractFCParams(128, 2, 'rnet/fc2_1')
47+
const fc2_2 = extractFCParams(128, 4, 'rnet/fc2_2')
48+
49+
return { ...sharedParams, fc1, prelu4_alpha, fc2_1, fc2_2 }
50+
}
51+
52+
function extractONetParams(): ONetParams {
53+
54+
const sharedParams = extractSharedParams([3, 32, 64, 64], 'onet')
55+
const conv4 = extractConvParams(64, 128, 2, 'onet/conv4')
56+
const prelu4_alpha = extractPReluParams(128, 'onet/prelu4_alpha')
57+
const fc1 = extractFCParams(1152, 256, 'onet/fc1')
58+
const prelu5_alpha = extractPReluParams(256, 'onet/prelu4_alpha')
59+
const fc2_1 = extractFCParams(256, 2, 'onet/fc2_1')
60+
const fc2_2 = extractFCParams(256, 4, 'onet/fc2_2')
61+
const fc2_3 = extractFCParams(256, 10, 'onet/fc2_2')
62+
63+
return { ...sharedParams, conv4, prelu4_alpha, fc1, prelu5_alpha, fc2_1, fc2_2, fc2_3 }
64+
}
65+
66+
return {
67+
extractPNetParams,
68+
extractRNetParams,
69+
extractONetParams
70+
}
71+
72+
}
73+
74+
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
75+
76+
const {
77+
extractWeights,
78+
getRemainingWeights
79+
} = extractWeightsFactory(weights)
80+
81+
const paramMappings: ParamMapping[] = []
82+
83+
const {
84+
extractPNetParams,
85+
extractRNetParams,
86+
extractONetParams
87+
} = extractorsFactory(extractWeights, paramMappings)
88+
89+
const pnet = extractPNetParams()
90+
const rnet = extractRNetParams()
91+
const onet = extractONetParams()
92+
93+
return { params: { pnet, rnet, onet }, paramMappings }
94+
}

src/mtcnn/index.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { Mtcnn } from './Mtcnn';
2+
3+
export * from './Mtcnn';
4+
5+
export function mtcnn(weights: Float32Array) {
6+
const net = new Mtcnn()
7+
net.extractWeights(weights)
8+
return net
9+
}

src/mtcnn/prelu.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
export function prelu(x: tf.Tensor4D, alpha: tf.Tensor1D): tf.Tensor4D {
4+
return tf.tidy(() =>
5+
tf.add(
6+
tf.relu(x),
7+
tf.mul(alpha, tf.neg(tf.relu(tf.neg(x))))
8+
)
9+
)
10+
}

src/mtcnn/sharedLayers.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { convLayer } from '../commons/convLayer';
4+
import { prelu } from './prelu';
5+
import { SharedParams } from './types';
6+
7+
export function sharedLayer(x: tf.Tensor4D, params: SharedParams, isPnet: boolean = false) {
8+
return tf.tidy(() => {
9+
10+
let out = convLayer(x, params.conv1, 'valid')
11+
out = prelu(out, params.prelu1_alpha)
12+
out = tf.maxPool(out, isPnet ? [2, 2]: [3, 3], [2, 2], 'same')
13+
out = convLayer(out, params.conv2, 'valid')
14+
out = prelu(out, params.prelu2_alpha)
15+
out = isPnet ? out : tf.maxPool(out, [3, 3], [2, 2], 'valid')
16+
out = convLayer(out, params.conv3, 'valid')
17+
out = prelu(out, params.prelu3_alpha)
18+
19+
return out
20+
})
21+
}

0 commit comments

Comments
 (0)