Skip to content

Commit 5d87492

Browse files
implemented mtcnn model loading from url + expose mtcnn to global api + fixed some minor issues
1 parent 4eed7a3 commit 5d87492

21 files changed

+218
-57
lines changed

src/faceDetectionNet/FaceDetection.ts renamed to src/FaceDetection.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import { Rect } from '../Rect';
2-
import { Dimensions } from '../types';
1+
import { Rect } from './Rect';
2+
import { Dimensions } from './types';
33

44
export class FaceDetection {
55
private _score: number

src/FullFaceDescription.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { FaceDetection } from './faceDetectionNet/FaceDetection';
1+
import { FaceDetection } from './FaceDetection';
22
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
33

44
export class FullFaceDescription {

src/drawing/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
1+
import { FaceDetection } from '../FaceDetection';
22
import { FaceLandmarks68 } from '../faceLandmarkNet';
33
import { FaceLandmarks } from '../FaceLandmarks';
44
import { Point } from '../Point';

src/extractFaceTensors.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { FaceDetection } from './faceDetectionNet/FaceDetection';
3+
import { FaceDetection } from './FaceDetection';
44
import { Rect } from './Rect';
55
import { toNetInput } from './toNetInput';
66
import { TNetInput } from './types';

src/extractFaces.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import { FaceDetection } from './faceDetectionNet/FaceDetection';
1+
import { FaceDetection } from './FaceDetection';
22
import { Rect } from './Rect';
33
import { toNetInput } from './toNetInput';
44
import { TNetInput } from './types';
55
import { createCanvas, getContext2dOrThrow, imageTensorToCanvas } from './utils';
6-
import * as tf from '@tensorflow/tfjs-core';
76

87
/**
98
* Extracts the image regions containing the detected faces.

src/faceDetectionNet/FaceDetectionNet.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { NeuralNetwork } from '../commons/NeuralNetwork';
4+
import { FaceDetection } from '../FaceDetection';
45
import { NetInput } from '../NetInput';
56
import { Rect } from '../Rect';
67
import { toNetInput } from '../toNetInput';
78
import { TNetInput } from '../types';
89
import { extractParams } from './extractParams';
9-
import { FaceDetection } from './FaceDetection';
1010
import { loadQuantizedParams } from './loadQuantizedParams';
1111
import { mobileNetV1 } from './mobileNetV1';
1212
import { nonMaxSuppression } from './nonMaxSuppression';

src/faceDetectionNet/index.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import { FaceDetectionNet } from './FaceDetectionNet';
22

33
export * from './FaceDetectionNet';
4-
export * from './FaceDetection';
54

6-
export function faceDetectionNet(weights: Float32Array) {
5+
export function createFaceDetectionNet(weights: Float32Array) {
76
const net = new FaceDetectionNet()
87
net.extractWeights(weights)
98
return net
9+
}
10+
11+
export function faceDetectionNet(weights: Float32Array) {
12+
console.warn('faceDetectionNet(weights: Float32Array) will be deprecated in future, use createFaceDetectionNet instead')
13+
return createFaceDetectionNet(weights)
1014
}

src/faceDetectionNet/loadQuantizedParams.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import { tf } from '..';
1+
import * as tf from '@tensorflow/tfjs-core';
2+
23
import { disposeUnusedWeightTensors } from '../commons/disposeUnusedWeightTensors';
34
import { extractWeightEntryFactory } from '../commons/extractWeightEntryFactory';
4-
import { isTensor1D, isTensor3D, isTensor4D } from '../commons/isTensor';
5+
import { isTensor3D } from '../commons/isTensor';
56
import { loadWeightMap } from '../commons/loadWeightMap';
67
import { ConvParams, ParamMapping } from '../commons/types';
78
import { BoxPredictionParams, MobileNetV1, NetParams, PointwiseConvParams, PredictionLayerParams } from './types';

src/faceLandmarkNet/FaceLandmarks68.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import { getCenterPoint } from '../commons/getCenterPoint';
2-
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
2+
import { FaceDetection } from '../FaceDetection';
33
import { FaceLandmarks } from '../FaceLandmarks';
44
import { IPoint, Point } from '../Point';
55
import { Rect } from '../Rect';
6-
import { Dimensions } from '../types';
76

87
// face alignment constants
98
const relX = 0.5
@@ -70,7 +69,7 @@ export class FaceLandmarks68 extends FaceLandmarks {
7069
* @returns The bounding box of the aligned face.
7170
*/
7271
public align(
73-
detection?: Rect
72+
detection?: FaceDetection | Rect
7473
): Rect {
7574
if (detection) {
7675
const box = detection instanceof FaceDetection

src/faceLandmarkNet/index.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ import { FaceLandmarkNet } from './FaceLandmarkNet';
33
export * from './FaceLandmarkNet';
44
export * from './FaceLandmarks68';
55

6-
export function faceLandmarkNet(weights: Float32Array) {
6+
export function createFaceLandmarkNet(weights: Float32Array) {
77
const net = new FaceLandmarkNet()
88
net.extractWeights(weights)
99
return net
10+
}
11+
12+
export function faceLandmarkNet(weights: Float32Array) {
13+
console.warn('faceLandmarkNet(weights: Float32Array) will be deprecated in future, use createFaceLandmarkNet instead')
14+
return createFaceLandmarkNet(weights)
1015
}

src/faceRecognitionNet/index.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@ import { FaceRecognitionNet } from './FaceRecognitionNet';
22

33
export * from './FaceRecognitionNet';
44

5-
export function faceRecognitionNet(weights: Float32Array) {
5+
export function createFaceRecognitionNet(weights: Float32Array) {
66
const net = new FaceRecognitionNet()
77
net.extractWeights(weights)
88
return net
9+
}
10+
11+
export function faceRecognitionNet(weights: Float32Array) {
12+
console.warn('faceRecognitionNet(weights: Float32Array) will be deprecated in future, use createFaceRecognitionNet instead')
13+
return createFaceRecognitionNet(weights)
914
}

src/globalApi.ts

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,53 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { allFacesFactory } from './allFacesFactory';
4-
import { FaceDetection } from './faceDetectionNet/FaceDetection';
4+
import { FaceDetection } from './FaceDetection';
55
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
66
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
77
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
88
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
99
import { FullFaceDescription } from './FullFaceDescription';
10+
import { getDefaultMtcnnForwardParams } from './mtcnn/getDefaultMtcnnForwardParams';
11+
import { Mtcnn } from './mtcnn/Mtcnn';
12+
import { MtcnnForwardParams, MtcnnResult } from './mtcnn/types';
1013
import { NetInput } from './NetInput';
1114
import { TNetInput } from './types';
1215

1316
export const detectionNet = new FaceDetectionNet()
1417
export const landmarkNet = new FaceLandmarkNet()
1518
export const recognitionNet = new FaceRecognitionNet()
1619

20+
// nets need more specific names, to avoid ambiguity in future
21+
// when alternative net implementations are provided
22+
export const nets = {
23+
ssdMobilenet: detectionNet,
24+
faceLandmark68Net: landmarkNet,
25+
faceNet: recognitionNet,
26+
mtcnn: new Mtcnn()
27+
}
28+
1729
export function loadFaceDetectionModel(url: string) {
18-
return detectionNet.load(url)
30+
return nets.ssdMobilenet.load(url)
1931
}
2032

2133
export function loadFaceLandmarkModel(url: string) {
22-
return landmarkNet.load(url)
34+
return nets.faceLandmark68Net.load(url)
2335
}
2436

2537
export function loadFaceRecognitionModel(url: string) {
26-
return recognitionNet.load(url)
38+
return nets.faceNet.load(url)
39+
}
40+
41+
export function loadMtcnnModel(url: string) {
42+
return nets.mtcnn.load(url)
2743
}
2844

2945
export function loadModels(url: string) {
3046
return Promise.all([
3147
loadFaceDetectionModel(url),
3248
loadFaceLandmarkModel(url),
33-
loadFaceRecognitionModel(url)
49+
loadFaceRecognitionModel(url),
50+
loadMtcnnModel(url)
3451
])
3552
}
3653

@@ -39,19 +56,26 @@ export function locateFaces(
3956
minConfidence?: number,
4057
maxResults?: number
4158
): Promise<FaceDetection[]> {
42-
return detectionNet.locateFaces(input, minConfidence, maxResults)
59+
return nets.ssdMobilenet.locateFaces(input, minConfidence, maxResults)
4360
}
4461

4562
export function detectLandmarks(
4663
input: TNetInput
4764
): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
48-
return landmarkNet.detectLandmarks(input)
65+
return nets.faceLandmark68Net.detectLandmarks(input)
4966
}
5067

5168
export function computeFaceDescriptor(
5269
input: TNetInput
5370
): Promise<Float32Array | Float32Array[]> {
54-
return recognitionNet.computeFaceDescriptor(input)
71+
return nets.faceNet.computeFaceDescriptor(input)
72+
}
73+
74+
export function mtcnn(
75+
input: TNetInput,
76+
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
77+
): Promise<MtcnnResult[]> {
78+
return nets.mtcnn.forward(input, forwardParameters)
5579
}
5680

5781
export const allFaces: (

src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ export {
44
tf
55
}
66

7+
8+
export * from './FaceDetection';
79
export * from './FullFaceDescription';
810
export * from './NetInput';
911
export * from './Point';

src/mtcnn/Mtcnn.ts

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { NeuralNetwork } from '../commons/NeuralNetwork';
4-
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
4+
import { FaceDetection } from '../FaceDetection';
55
import { NetInput } from '../NetInput';
66
import { Point } from '../Point';
77
import { Rect } from '../Rect';
88
import { toNetInput } from '../toNetInput';
99
import { TNetInput } from '../types';
1010
import { bgrToRgbTensor } from './bgrToRgbTensor';
11+
import { CELL_SIZE } from './config';
1112
import { extractParams } from './extractParams';
1213
import { FaceLandmarks5 } from './FaceLandmarks5';
14+
import { getDefaultMtcnnForwardParams } from './getDefaultMtcnnForwardParams';
1315
import { getSizesForScale } from './getSizesForScale';
16+
import { loadQuantizedParams } from './loadQuantizedParams';
1417
import { pyramidDown } from './pyramidDown';
1518
import { stage1 } from './stage1';
1619
import { stage2 } from './stage2';
1720
import { stage3 } from './stage3';
18-
import { MtcnnResult, NetParams } from './types';
21+
import { MtcnnForwardParams, MtcnnResult, NetParams } from './types';
1922

2023
export class Mtcnn extends NeuralNetwork<NetParams> {
2124

@@ -25,10 +28,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
2528

2629
public async forwardInput(
2730
input: NetInput,
28-
minFaceSize: number = 20,
29-
scaleFactor: number = 0.709,
30-
maxNumScales: number = 10,
31-
scoreThresholds: number[] = [0.6, 0.7, 0.7]
31+
{ minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps } = getDefaultMtcnnForwardParams()
3232
): Promise<{ results: MtcnnResult[], stats: any }> {
3333

3434
const { params } = this
@@ -64,10 +64,10 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
6464

6565
const [height, width] = imgTensor.shape.slice(1)
6666

67-
const scales = pyramidDown(minFaceSize, scaleFactor, [height, width])
67+
const scales = scaleSteps || pyramidDown(minFaceSize, scaleFactor, [height, width])
6868
.filter(scale => {
6969
const sizes = getSizesForScale(scale, [height, width])
70-
return Math.min(sizes.width, sizes.height) > 48
70+
return Math.min(sizes.width, sizes.height) > CELL_SIZE
7171
})
7272
.slice(0, maxNumScales)
7373

@@ -124,38 +124,31 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
124124

125125
public async forward(
126126
input: TNetInput,
127-
minFaceSize: number = 20,
128-
scaleFactor: number = 0.709,
129-
maxNumScales: number = 10,
130-
scoreThresholds: number[] = [0.6, 0.7, 0.7]
127+
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
131128
): Promise<MtcnnResult[]> {
132129
return (
133130
await this.forwardInput(
134131
await toNetInput(input, true, true),
135-
minFaceSize,
136-
scaleFactor,
137-
maxNumScales,
138-
scoreThresholds
132+
forwardParameters
139133
)
140134
).results
141135
}
142136

143137
public async forwardWithStats(
144138
input: TNetInput,
145-
minFaceSize: number = 20,
146-
scaleFactor: number = 0.709,
147-
maxNumScales: number = 10,
148-
scoreThresholds: number[] = [0.6, 0.7, 0.7]
139+
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
149140
): Promise<{ results: MtcnnResult[], stats: any }> {
150141
return this.forwardInput(
151142
await toNetInput(input, true, true),
152-
minFaceSize,
153-
scaleFactor,
154-
maxNumScales,
155-
scoreThresholds
143+
forwardParameters
156144
)
157145
}
158146

147+
// none of the param tensors are quantized yet
148+
protected loadQuantizedParams(uri: string | undefined) {
149+
return loadQuantizedParams(uri)
150+
}
151+
159152
protected extractParams(weights: Float32Array) {
160153
return extractParams(weights)
161154
}

src/mtcnn/extractParams.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings
5555
const conv4 = extractConvParams(64, 128, 2, 'onet/conv4')
5656
const prelu4_alpha = extractPReluParams(128, 'onet/prelu4_alpha')
5757
const fc1 = extractFCParams(1152, 256, 'onet/fc1')
58-
const prelu5_alpha = extractPReluParams(256, 'onet/prelu4_alpha')
58+
const prelu5_alpha = extractPReluParams(256, 'onet/prelu5_alpha')
5959
const fc2_1 = extractFCParams(256, 2, 'onet/fc2_1')
6060
const fc2_2 = extractFCParams(256, 4, 'onet/fc2_2')
61-
const fc2_3 = extractFCParams(256, 10, 'onet/fc2_2')
61+
const fc2_3 = extractFCParams(256, 10, 'onet/fc2_3')
6262

6363
return { ...sharedParams, conv4, prelu4_alpha, fc1, prelu5_alpha, fc2_1, fc2_2, fc2_3 }
6464
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import { MtcnnForwardParams } from './types';
2+
3+
export function getDefaultMtcnnForwardParams(): MtcnnForwardParams {
4+
return {
5+
minFaceSize: 20,
6+
scaleFactor: 0.709,
7+
maxNumScales: 10,
8+
scoreThresholds: [0.6, 0.7, 0.7]
9+
}
10+
}

src/mtcnn/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import { Mtcnn } from './Mtcnn';
22

33
export * from './Mtcnn';
4+
export * from './FaceLandmarks5';
45

5-
export function mtcnn(weights: Float32Array) {
6+
export function createMtcnn(weights: Float32Array) {
67
const net = new Mtcnn()
78
net.extractWeights(weights)
89
return net

0 commit comments

Comments
 (0)