Skip to content

Commit e04770c

Browse files
finished mtcnn implementation
1 parent 5e957ec commit e04770c

18 files changed

+326
-147
lines changed

src/FaceLandmarks.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { Point } from './Point';
2+
import { Dimensions } from './types';
3+
4+
export class FaceLandmarks {
5+
protected _imageWidth: number
6+
protected _imageHeight: number
7+
protected _shift: Point
8+
protected _faceLandmarks: Point[]
9+
10+
constructor(
11+
relativeFaceLandmarkPositions: Point[],
12+
imageDims: Dimensions,
13+
shift: Point = new Point(0, 0)
14+
) {
15+
const { width, height } = imageDims
16+
this._imageWidth = width
17+
this._imageHeight = height
18+
this._shift = shift
19+
this._faceLandmarks = relativeFaceLandmarkPositions.map(
20+
pt => pt.mul(new Point(width, height)).add(shift)
21+
)
22+
}
23+
24+
public getShift(): Point {
25+
return new Point(this._shift.x, this._shift.y)
26+
}
27+
28+
public getImageWidth(): number {
29+
return this._imageWidth
30+
}
31+
32+
public getImageHeight(): number {
33+
return this._imageHeight
34+
}
35+
36+
public getPositions(): Point[] {
37+
return this._faceLandmarks
38+
}
39+
40+
public getRelativePositions(): Point[] {
41+
return this._faceLandmarks.map(
42+
pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
43+
)
44+
}
45+
}

src/FullFaceDescription.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
import { FaceDetection } from './faceDetectionNet/FaceDetection';
2-
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
2+
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
33

44
export class FullFaceDescription {
55
constructor(
66
private _detection: FaceDetection,
7-
private _landmarks: FaceLandmarks,
7+
private _landmarks: FaceLandmarks68,
88
private _descriptor: Float32Array
99
) {}
1010

1111
public get detection(): FaceDetection {
1212
return this._detection
1313
}
1414

15-
public get landmarks(): FaceLandmarks {
15+
public get landmarks(): FaceLandmarks68 {
1616
return this._landmarks
1717
}
1818

src/allFacesFactory.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { extractFaceTensors } from './extractFaceTensors';
22
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
33
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
4-
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
4+
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
55
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
66
import { FullFaceDescription } from './FullFaceDescription';
77
import { TNetInput } from './types';
@@ -22,10 +22,10 @@ export function allFacesFactory(
2222
const faceTensors = await extractFaceTensors(input, detections)
2323

2424
const faceLandmarksByFace = useBatchProcessing
25-
? await landmarkNet.detectLandmarks(faceTensors) as FaceLandmarks[]
25+
? await landmarkNet.detectLandmarks(faceTensors) as FaceLandmarks68[]
2626
: await Promise.all(faceTensors.map(
2727
faceTensor => landmarkNet.detectLandmarks(faceTensor)
28-
)) as FaceLandmarks[]
28+
)) as FaceLandmarks68[]
2929

3030
faceTensors.forEach(t => t.dispose())
3131

src/drawing/index.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
2-
import { FaceLandmarks } from '../faceLandmarkNet/FaceLandmarks';
2+
import { FaceLandmarks68 } from '../faceLandmarkNet';
3+
import { FaceLandmarks } from '../FaceLandmarks';
34
import { Point } from '../Point';
45
import { getContext2dOrThrow, resolveInput, round } from '../utils';
56
import { DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
@@ -150,7 +151,7 @@ export function drawLandmarks(
150151
const faceLandmarksArray = Array.isArray(faceLandmarks) ? faceLandmarks : [faceLandmarks]
151152

152153
faceLandmarksArray.forEach(landmarks => {
153-
if (drawLines) {
154+
if (drawLines && landmarks instanceof FaceLandmarks68) {
154155
ctx.strokeStyle = color
155156
ctx.lineWidth = lineWidth
156157
drawContour(ctx, landmarks.getJawOutline())

src/faceLandmarkNet/FaceLandmarkNet.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { toNetInput } from '../toNetInput';
99
import { TNetInput } from '../types';
1010
import { isEven } from '../utils';
1111
import { extractParams } from './extractParams';
12-
import { FaceLandmarks } from './FaceLandmarks';
12+
import { FaceLandmarks68 } from './FaceLandmarks68';
1313
import { fullyConnectedLayer } from './fullyConnectedLayer';
1414
import { loadQuantizedParams } from './loadQuantizedParams';
1515
import { NetParams } from './types';
@@ -93,7 +93,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
9393
return this.forwardInput(await toNetInput(input, true))
9494
}
9595

96-
public async detectLandmarks(input: TNetInput): Promise<FaceLandmarks | FaceLandmarks[]> {
96+
public async detectLandmarks(input: TNetInput): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
9797
const netInput = await toNetInput(input, true)
9898

9999
const landmarkTensors = tf.tidy(
@@ -106,7 +106,7 @@ export class FaceLandmarkNet extends NeuralNetwork<NetParams> {
106106
const xCoords = landmarksArray.filter((_, i) => isEven(i))
107107
const yCoords = landmarksArray.filter((_, i) => !isEven(i))
108108

109-
return new FaceLandmarks(
109+
return new FaceLandmarks68(
110110
Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),
111111
{
112112
height: netInput.getInputHeight(batchIdx),

src/faceLandmarkNet/FaceLandmarks.ts renamed to src/faceLandmarkNet/FaceLandmarks68.ts

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { getCenterPoint } from '../commons/getCenterPoint';
22
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
3+
import { FaceLandmarks } from '../FaceLandmarks';
34
import { IPoint, Point } from '../Point';
45
import { Rect } from '../Rect';
56
import { Dimensions } from '../types';
@@ -9,48 +10,7 @@ const relX = 0.5
910
const relY = 0.43
1011
const relScale = 0.45
1112

12-
export class FaceLandmarks {
13-
private _imageWidth: number
14-
private _imageHeight: number
15-
private _shift: Point
16-
private _faceLandmarks: Point[]
17-
18-
constructor(
19-
relativeFaceLandmarkPositions: Point[],
20-
imageDims: Dimensions,
21-
shift: Point = new Point(0, 0)
22-
) {
23-
const { width, height } = imageDims
24-
this._imageWidth = width
25-
this._imageHeight = height
26-
this._shift = shift
27-
this._faceLandmarks = relativeFaceLandmarkPositions.map(
28-
pt => pt.mul(new Point(width, height)).add(shift)
29-
)
30-
}
31-
32-
public getShift(): Point {
33-
return new Point(this._shift.x, this._shift.y)
34-
}
35-
36-
public getImageWidth(): number {
37-
return this._imageWidth
38-
}
39-
40-
public getImageHeight(): number {
41-
return this._imageHeight
42-
}
43-
44-
public getPositions(): Point[] {
45-
return this._faceLandmarks
46-
}
47-
48-
public getRelativePositions(): Point[] {
49-
return this._faceLandmarks.map(
50-
pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
51-
)
52-
}
53-
13+
export class FaceLandmarks68 extends FaceLandmarks {
5414
public getJawOutline(): Point[] {
5515
return this._faceLandmarks.slice(0, 17)
5616
}
@@ -79,22 +39,22 @@ export class FaceLandmarks {
7939
return this._faceLandmarks.slice(48, 68)
8040
}
8141

82-
public forSize(width: number, height: number): FaceLandmarks {
83-
return new FaceLandmarks(
42+
public forSize(width: number, height: number): FaceLandmarks68 {
43+
return new FaceLandmarks68(
8444
this.getRelativePositions(),
8545
{ width, height }
8646
)
8747
}
8848

89-
public shift(x: number, y: number): FaceLandmarks {
90-
return new FaceLandmarks(
49+
public shift(x: number, y: number): FaceLandmarks68 {
50+
return new FaceLandmarks68(
9151
this.getRelativePositions(),
9252
{ width: this._imageWidth, height: this._imageHeight },
9353
new Point(x, y)
9454
)
9555
}
9656

97-
public shiftByPoint(pt: IPoint): FaceLandmarks {
57+
public shiftByPoint(pt: IPoint): FaceLandmarks68 {
9858
return this.shift(pt.x, pt.y)
9959
}
10060

src/faceLandmarkNet/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { FaceLandmarkNet } from './FaceLandmarkNet';
22

33
export * from './FaceLandmarkNet';
4-
export * from './FaceLandmarks';
4+
export * from './FaceLandmarks68';
55

66
export function faceLandmarkNet(weights: Float32Array) {
77
const net = new FaceLandmarkNet()

src/globalApi.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { allFacesFactory } from './allFacesFactory';
44
import { FaceDetection } from './faceDetectionNet/FaceDetection';
55
import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
66
import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
7-
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
7+
import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
88
import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
99
import { FullFaceDescription } from './FullFaceDescription';
1010
import { NetInput } from './NetInput';
@@ -44,7 +44,7 @@ export function locateFaces(
4444

4545
export function detectLandmarks(
4646
input: TNetInput
47-
): Promise<FaceLandmarks | FaceLandmarks[]> {
47+
): Promise<FaceLandmarks68 | FaceLandmarks68[]> {
4848
return landmarkNet.detectLandmarks(input)
4949
}
5050

src/mtcnn/BoundingBox.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,13 @@ export class BoundingBox {
8989

9090
return { dy, edy, dx, edx, y, ey, x, ex, w, h }
9191
}
92+
93+
public calibrate(region: BoundingBox) {
94+
return new BoundingBox(
95+
this.left + (region.left * this.width),
96+
this.top + (region.top * this.height),
97+
this.right + (region.right * this.width),
98+
this.bottom + (region.bottom * this.height)
99+
).toSquare().round()
100+
}
92101
}

src/mtcnn/FaceLandmarks5.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { FaceLandmarks } from '../FaceLandmarks';
2+
import { IPoint, Point } from '../Point';
3+
4+
export class FaceLandmarks5 extends FaceLandmarks {
5+
6+
public forSize(width: number, height: number): FaceLandmarks5 {
7+
return new FaceLandmarks5(
8+
this.getRelativePositions(),
9+
{ width, height }
10+
)
11+
}
12+
13+
public shift(x: number, y: number): FaceLandmarks5 {
14+
return new FaceLandmarks5(
15+
this.getRelativePositions(),
16+
{ width: this._imageWidth, height: this._imageHeight },
17+
new Point(x, y)
18+
)
19+
}
20+
21+
public shiftByPoint(pt: IPoint): FaceLandmarks5 {
22+
return this.shift(pt.x, pt.y)
23+
}
24+
}

src/mtcnn/Mtcnn.ts

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { NeuralNetwork } from '../commons/NeuralNetwork';
4+
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
45
import { NetInput } from '../NetInput';
6+
import { Point } from '../Point';
7+
import { Rect } from '../Rect';
58
import { toNetInput } from '../toNetInput';
69
import { TNetInput } from '../types';
710
import { bgrToRgbTensor } from './bgrToRgbTensor';
811
import { extractParams } from './extractParams';
12+
import { FaceLandmarks5 } from './FaceLandmarks5';
913
import { pyramidDown } from './pyramidDown';
1014
import { stage1 } from './stage1';
1115
import { stage2 } from './stage2';
16+
import { stage3 } from './stage3';
1217
import { NetParams } from './types';
1318

1419
export class Mtcnn extends NeuralNetwork<NetParams> {
@@ -22,7 +27,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
2227
minFaceSize: number = 20,
2328
scaleFactor: number = 0.709,
2429
scoreThresholds: number[] = [0.6, 0.7, 0.7]
25-
): Promise<tf.Tensor2D> {
30+
): Promise<any> {
2631

2732
const { params } = this
2833

@@ -43,19 +48,46 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
4348
)
4449
)
4550

46-
const scales = pyramidDown(minFaceSize, scaleFactor, imgTensor.shape.slice(1))
51+
const [height, width] = imgTensor.shape.slice(1)
52+
53+
const scales = pyramidDown(minFaceSize, scaleFactor, [height, width])
4754
const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet)
4855

4956
// using the inputCanvas to extract and resize the image patches, since it is faster
5057
// than doing this on the gpu
51-
const out2 = await stage2(inputCanvas, out1, scoreThresholds[1], params.rnet)
52-
53-
58+
const out2 = await stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet)
59+
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet)
5460

5561
imgTensor.dispose()
5662
input.dispose()
5763

58-
return tf.tensor2d([0], [1, 1])
64+
const faceDetections = out3.boxes.map((box, idx) =>
65+
new FaceDetection(
66+
out3.scores[idx],
67+
new Rect(
68+
box.left / width,
69+
box.top / height,
70+
box.width / width,
71+
box.height / height
72+
),
73+
{
74+
height,
75+
width
76+
}
77+
)
78+
)
79+
80+
const faceLandmarks = out3.points.map(pts =>
81+
new FaceLandmarks5(
82+
pts.map(pt => pt.div(new Point(width, height))),
83+
{ width, height }
84+
)
85+
)
86+
87+
return {
88+
faceDetections,
89+
faceLandmarks
90+
}
5991
}
6092

6193
public async forward(

0 commit comments

Comments
 (0)