Skip to content

Commit 0ea55f9

Browse files
fixed mtcnn face landmark positions
1 parent abf82d4 commit 0ea55f9

File tree

7 files changed

+35
-42
lines changed

7 files changed

+35
-42
lines changed
Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,40 @@
11
import { FaceDetection } from './FaceDetection';
22
import { FaceLandmarks } from './FaceLandmarks';
3+
import { FaceLandmarks68 } from './FaceLandmarks68';
34

4-
export class FaceDetectionWithLandmarks {
5+
export class FaceDetectionWithLandmarks<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> {
56
private _detection: FaceDetection
6-
private _relativeLandmarks: FaceLandmarks
7+
private _unshiftedLandmarks: TFaceLandmarks
78

89
constructor(
910
detection: FaceDetection,
10-
relativeLandmarks: FaceLandmarks
11+
unshiftedLandmarks: TFaceLandmarks
1112
) {
1213
this._detection = detection
13-
this._relativeLandmarks = relativeLandmarks
14+
this._unshiftedLandmarks = unshiftedLandmarks
1415
}
1516

1617
public get detection(): FaceDetection { return this._detection }
17-
public get relativeLandmarks(): FaceLandmarks { return this._relativeLandmarks }
18+
public get unshiftedLandmarks(): TFaceLandmarks { return this._unshiftedLandmarks }
1819

1920
public get alignedRect(): FaceDetection {
2021
const rect = this.landmarks.align()
2122
const { imageDims } = this.detection
2223
return new FaceDetection(this._detection.score, rect.rescale(imageDims.reverse()), imageDims)
2324
}
2425

25-
public get landmarks(): FaceLandmarks {
26+
public get landmarks(): TFaceLandmarks {
2627
const { x, y } = this.detection.box
27-
return this._relativeLandmarks.shift(x, y)
28+
return this._unshiftedLandmarks.shiftBy(x, y)
2829
}
2930

30-
public forSize(width: number, height: number): FaceDetectionWithLandmarks {
31+
// aliases for backward compatibility
32+
get faceDetection(): FaceDetection { return this.detection }
33+
get faceLandmarks(): TFaceLandmarks { return this.landmarks }
34+
35+
public forSize(width: number, height: number): FaceDetectionWithLandmarks<TFaceLandmarks> {
3136
const resizedDetection = this._detection.forSize(width, height)
32-
const resizedLandmarks = this._relativeLandmarks.forSize(resizedDetection.box.width, resizedDetection.box.height)
33-
return new FaceDetectionWithLandmarks(resizedDetection, resizedLandmarks)
37+
const resizedLandmarks = this._unshiftedLandmarks.forSize<TFaceLandmarks>(resizedDetection.box.width, resizedDetection.box.height)
38+
return new FaceDetectionWithLandmarks<TFaceLandmarks>(resizedDetection, resizedLandmarks)
3439
}
3540
}

src/classes/FaceLandmarks.ts

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@ export class FaceLandmarks {
2525
)
2626
}
2727

28-
public getShift(): Point {
29-
return new Point(this._shift.x, this._shift.y)
30-
}
31-
28+
public get shift(): Point { return new Point(this._shift.x, this._shift.y) }
3229
public get imageWidth(): number { return this._imgDims.width }
3330
public get imageHeight(): number { return this._imgDims.height }
3431
public get positions(): Point[] { return this._positions }
@@ -45,7 +42,7 @@ export class FaceLandmarks {
4542
)
4643
}
4744

48-
public shift<T extends FaceLandmarks>(x: number, y: number): T {
45+
public shiftBy<T extends FaceLandmarks>(x: number, y: number): T {
4946
return new (this.constructor as any)(
5047
this.relativePositions,
5148
this._imgDims,
@@ -54,7 +51,7 @@ export class FaceLandmarks {
5451
}
5552

5653
public shiftByPoint<T extends FaceLandmarks>(pt: Point): T {
57-
return this.shift(pt.x, pt.y)
54+
return this.shiftBy(pt.x, pt.y)
5855
}
5956

6057
/**
@@ -76,7 +73,7 @@ export class FaceLandmarks {
7673
? detection.box.floor()
7774
: detection
7875

79-
return this.shift(box.x, box.y).align()
76+
return this.shiftBy(box.x, box.y).align()
8077
}
8178

8279
const centers = this.getRefPointsForAlignment()

src/classes/FullFaceDescription.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
import { FaceDetection } from './FaceDetection';
22
import { FaceDetectionWithLandmarks } from './FaceDetectionWithLandmarks';
33
import { FaceLandmarks } from './FaceLandmarks';
4+
import { FaceLandmarks68 } from './FaceLandmarks68';
45

5-
export class FullFaceDescription extends FaceDetectionWithLandmarks {
6+
export class FullFaceDescription<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68> extends FaceDetectionWithLandmarks<TFaceLandmarks> {
67
private _descriptor: Float32Array
78

89
constructor(
910
detection: FaceDetection,
10-
landmarks: FaceLandmarks,
11+
unshiftedLandmarks: TFaceLandmarks,
1112
descriptor: Float32Array
1213
) {
13-
super(detection, landmarks)
14+
super(detection, unshiftedLandmarks)
1415
this._descriptor = descriptor
1516
}
1617

1718
public get descriptor(): Float32Array {
1819
return this._descriptor
1920
}
2021

21-
public forSize(width: number, height: number): FullFaceDescription {
22+
public forSize(width: number, height: number): FullFaceDescription<TFaceLandmarks> {
2223
const { detection, landmarks } = super.forSize(width, height)
23-
return new FullFaceDescription(detection, landmarks, this.descriptor)
24+
return new FullFaceDescription<TFaceLandmarks>(detection, landmarks, this.descriptor)
2425
}
2526
}

src/globalApi/nets.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ import { TNetInput } from 'tfjs-image-recognition-base';
22
import { ITinyYolov2Options } from 'tfjs-tiny-yolov2';
33

44
import { FaceDetection } from '../classes/FaceDetection';
5+
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
56
import { FaceLandmarks68 } from '../classes/FaceLandmarks68';
67
import { FaceLandmark68Net } from '../faceLandmarkNet/FaceLandmark68Net';
78
import { FaceLandmark68TinyNet } from '../faceLandmarkNet/FaceLandmark68TinyNet';
89
import { FaceRecognitionNet } from '../faceRecognitionNet/FaceRecognitionNet';
910
import { Mtcnn } from '../mtcnn/Mtcnn';
1011
import { MtcnnOptions } from '../mtcnn/MtcnnOptions';
11-
import { MtcnnResult } from '../mtcnn/MtcnnResult';
1212
import { SsdMobilenetv1 } from '../ssdMobilenetv1/SsdMobilenetv1';
1313
import { SsdMobilenetv1Options } from '../ssdMobilenetv1/SsdMobilenetv1Options';
1414
import { TinyFaceDetector } from '../tinyFaceDetector/TinyFaceDetector';
@@ -63,7 +63,7 @@ export const tinyYolov2 = (input: TNetInput, options: ITinyYolov2Options): Promi
6363
* @param options (optional, default: see MtcnnOptions constructor for default parameters).
6464
* @returns Bounding box of each face with score and 5 point face landmarks.
6565
*/
66-
export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<MtcnnResult[]> =>
66+
export const mtcnn = (input: TNetInput, options: MtcnnOptions): Promise<FaceDetectionWithLandmarks[]> =>
6767
nets.mtcnn.forward(input, options)
6868

6969
/**

src/mtcnn/Mtcnn.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ import { extractParams } from './extractParams';
99
import { getSizesForScale } from './getSizesForScale';
1010
import { loadQuantizedParams } from './loadQuantizedParams';
1111
import { IMtcnnOptions, MtcnnOptions } from './MtcnnOptions';
12-
import { MtcnnResult } from './MtcnnResult';
1312
import { pyramidDown } from './pyramidDown';
1413
import { stage1 } from './stage1';
1514
import { stage2 } from './stage2';
1615
import { stage3 } from './stage3';
1716
import { NetParams } from './types';
17+
import { FaceDetectionWithLandmarks } from '../classes/FaceDetectionWithLandmarks';
1818

1919
export class Mtcnn extends NeuralNetwork<NetParams> {
2020

@@ -25,7 +25,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
2525
public async forwardInput(
2626
input: NetInput,
2727
forwardParams: IMtcnnOptions = {}
28-
): Promise<{ results: MtcnnResult[], stats: any }> {
28+
): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {
2929

3030
const { params } = this
3131

@@ -101,7 +101,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
101101
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)
102102
stats.total_stage3 = Date.now() - ts
103103

104-
const results = out3.boxes.map((box, idx) => new MtcnnResult(
104+
const results = out3.boxes.map((box, idx) => new FaceDetectionWithLandmarks(
105105
new FaceDetection(
106106
out3.scores[idx],
107107
new Rect(
@@ -116,8 +116,8 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
116116
}
117117
),
118118
new FaceLandmarks5(
119-
out3.points[idx].map(pt => pt.div(new Point(width, height))),
120-
{ width, height }
119+
out3.points[idx].map(pt => pt.sub(new Point(box.left, box.top)).div(new Point(box.width, box.height))),
120+
{ width: box.width, height: box.height }
121121
)
122122
))
123123

@@ -127,7 +127,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
127127
public async forward(
128128
input: TNetInput,
129129
forwardParams: IMtcnnOptions = {}
130-
): Promise<MtcnnResult[]> {
130+
): Promise<FaceDetectionWithLandmarks[]> {
131131
return (
132132
await this.forwardInput(
133133
await toNetInput(input),
@@ -139,7 +139,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
139139
public async forwardWithStats(
140140
input: TNetInput,
141141
forwardParams: IMtcnnOptions = {}
142-
): Promise<{ results: MtcnnResult[], stats: any }> {
142+
): Promise<{ results: FaceDetectionWithLandmarks[], stats: any }> {
143143
return this.forwardInput(
144144
await toNetInput(input),
145145
forwardParams

src/mtcnn/MtcnnResult.ts

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/ssdMobilenetv1/SsdMobilenetv1.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export class SsdMobilenetv1 extends NeuralNetwork<NetParams> {
4646

4747
public async locateFaces(
4848
input: TNetInput,
49-
options: ISsdMobilenetv1Options
49+
options: ISsdMobilenetv1Options = {}
5050
): Promise<FaceDetection[]> {
5151

5252
const { maxResults, minConfidence } = new SsdMobilenetv1Options(options)

0 commit comments

Comments
 (0)