
Commit 930f85b

face alignment from 5 point face landmarks + allFacesMtcnn
1 parent 049997b commit 930f85b

6 files changed: +155 -101 lines


src/FaceLandmarks.ts

Lines changed: 70 additions & 1 deletion
@@ -1,6 +1,14 @@
-import { Point } from './Point';
+import { getCenterPoint } from './commons/getCenterPoint';
+import { FaceDetection } from './FaceDetection';
+import { IPoint, Point } from './Point';
+import { Rect } from './Rect';
 import { Dimensions } from './types';
 
+// face alignment constants
+const relX = 0.5
+const relY = 0.43
+const relScale = 0.45
+
 export class FaceLandmarks {
   protected _imageWidth: number
   protected _imageHeight: number
@@ -42,4 +50,65 @@ export class FaceLandmarks {
       pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
     )
   }
+
+  public forSize<T extends FaceLandmarks>(width: number, height: number): T {
+    return new (this.constructor as any)(
+      this.getRelativePositions(),
+      { width, height }
+    )
+  }
+
+  public shift<T extends FaceLandmarks>(x: number, y: number): T {
+    return new (this.constructor as any)(
+      this.getRelativePositions(),
+      { width: this._imageWidth, height: this._imageHeight },
+      new Point(x, y)
+    )
+  }
+
+  public shiftByPoint<T extends FaceLandmarks>(pt: IPoint): T {
+    return this.shift(pt.x, pt.y)
+  }
+
+  /**
+   * Aligns the face landmarks after face detection from the relative positions of the faces
+   * bounding box, or it's current shift. This function should be used to align the face images
+   * after face detection has been performed, before they are passed to the face recognition net.
+   * This will make the computed face descriptor more accurate.
+   *
+   * @param detection (optional) The bounding box of the face or the face detection result. If
+   * no argument was passed the position of the face landmarks are assumed to be relative to
+   * it's current shift.
+   * @returns The bounding box of the aligned face.
+   */
+  public align(
+    detection?: FaceDetection | Rect
+  ): Rect {
+    if (detection) {
+      const box = detection instanceof FaceDetection
+        ? detection.getBox().floor()
+        : detection
+
+      return this.shift(box.x, box.y).align()
+    }
+
+    const centers = this.getRefPointsForAlignment()
+
+    const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
+    const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
+    const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
+
+    const size = Math.floor(eyeToMouthDist / relScale)
+
+    const refPoint = getCenterPoint(centers)
+    // TODO: pad in case rectangle is out of image bounds
+    const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
+    const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
+
+    return new Rect(x, y, Math.min(size, this._imageWidth - x), Math.min(size, this._imageHeight - y))
+  }
+
+  protected getRefPointsForAlignment(): Point[] {
+    throw new Error('getRefPointsForAlignment not implemented by base class')
+  }
 }
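
For intuition, here is a small worked example of the align() geometry introduced above; it is not part of the commit itself. The reference point values are made up, while Point, getCenterPoint and the relX/relY/relScale constants are the ones from this diff.

  import { getCenterPoint } from './commons/getCenterPoint';
  import { Point } from './Point';

  // hypothetical eye centers and mouth center, in image coordinates
  const leftEye = new Point(60, 60)
  const rightEye = new Point(100, 60)
  const mouth = new Point(80, 110)

  const dist = (a: Point, b: Point) => a.sub(b).magnitude()
  const eyeToMouthDist = (dist(mouth, leftEye) + dist(mouth, rightEye)) / 2  // ~53.85
  const size = Math.floor(eyeToMouthDist / 0.45)                             // 119 (relScale)
  const refPoint = getCenterPoint([leftEye, rightEye, mouth])                // (80, 76.67)
  const x = Math.floor(Math.max(0, refPoint.x - 0.5 * size))                 // 20 (relX)
  const y = Math.floor(Math.max(0, refPoint.y - 0.43 * size))                // 25 (relY)
  // align() would therefore return roughly Rect(20, 25, 119, 119), clipped to the image bounds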

src/FullFaceDescription.ts

Lines changed: 3 additions & 3 deletions
@@ -1,18 +1,18 @@
 import { FaceDetection } from './FaceDetection';
-import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
+import { FaceLandmarks } from './FaceLandmarks';
 
 export class FullFaceDescription {
   constructor(
     private _detection: FaceDetection,
-    private _landmarks: FaceLandmarks68,
+    private _landmarks: FaceLandmarks,
     private _descriptor: Float32Array
   ) {}
 
   public get detection(): FaceDetection {
     return this._detection
   }
 
-  public get landmarks(): FaceLandmarks68 {
+  public get landmarks(): FaceLandmarks {
     return this._landmarks
   }
 
src/allFacesFactory.ts

Lines changed: 35 additions & 11 deletions
@@ -2,14 +2,16 @@ import { extractFaceTensors } from './extractFaceTensors';
 import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
 import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
 import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
-import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
 import { FullFaceDescription } from './FullFaceDescription';
+import { Mtcnn } from './mtcnn/Mtcnn';
+import { MtcnnForwardParams } from './mtcnn/types';
+import { Rect } from './Rect';
 import { TNetInput } from './types';
 
 export function allFacesFactory(
   detectionNet: FaceDetectionNet,
   landmarkNet: FaceLandmarkNet,
-  recognitionNet: FaceRecognitionNet
+  computeDescriptors: (input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) => Promise<Float32Array[]>
 ) {
   return async function(
     input: TNetInput,
@@ -32,20 +34,42 @@ export function allFacesFactory(
     const alignedFaceBoxes = faceLandmarksByFace.map(
       (landmarks, i) => landmarks.align(detections[i].getBox())
     )
-    const alignedFaceTensors = await extractFaceTensors(input, alignedFaceBoxes)
 
-    const descriptors = useBatchProcessing
-      ? await recognitionNet.computeFaceDescriptor(alignedFaceTensors) as Float32Array[]
-      : await Promise.all(alignedFaceTensors.map(
-          faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
-        )) as Float32Array[]
-
-    alignedFaceTensors.forEach(t => t.dispose())
+    const descriptors = await computeDescriptors(input, alignedFaceBoxes, useBatchProcessing)
 
     return detections.map((detection, i) =>
       new FullFaceDescription(
         detection,
-        faceLandmarksByFace[i].shiftByPoint(detection.getBox()),
+        faceLandmarksByFace[i].shiftByPoint<FaceLandmarks68>(detection.getBox()),
+        descriptors[i]
+      )
+    )
+
+  }
+}
+
+export function allFacesMtcnnFactory(
+  mtcnn: Mtcnn,
+  computeDescriptors: (input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) => Promise<Float32Array[]>
+) {
+  return async function(
+    input: TNetInput,
+    mtcnnForwardParams: MtcnnForwardParams,
+    useBatchProcessing: boolean = false
+  ): Promise<FullFaceDescription[]> {
+
+    const results = await mtcnn.forward(input, mtcnnForwardParams)
+
+    const alignedFaceBoxes = results.map(
+      ({ faceLandmarks }) => faceLandmarks.align()
+    )
+
+    const descriptors = await computeDescriptors(input, alignedFaceBoxes, useBatchProcessing)
+
+    return results.map(({ faceDetection, faceLandmarks }, i) =>
+      new FullFaceDescription(
+        faceDetection,
+        faceLandmarks,
         descriptors[i]
       )
     )
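
Note that both factories now receive the descriptor computation as an injected callback rather than a FaceRecognitionNet instance, so the SSD based and the MTCNN based pipeline can share a single implementation (computeDescriptorsFactory in globalApi.ts below). As a rough sketch of the expected callback shape, here is an illustrative stub; the zero filled descriptors and the length of 128 are assumptions for illustration only:

  import { Rect } from './Rect';
  import { TNetInput } from './types';

  // illustrative stub: returns dummy descriptors instead of running a recognition net
  const dummyComputeDescriptors = async (
    input: TNetInput,
    alignedFaceBoxes: Rect[],
    useBatchProcessing: boolean
  ): Promise<Float32Array[]> =>
    alignedFaceBoxes.map(() => new Float32Array(128))

  // it could then be plugged into either factory, e.g.:
  // const allFacesDummy = allFacesFactory(detectionNet, landmarkNet, dummyComputeDescriptors)
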
src/faceLandmarkNet/FaceLandmarks68.ts

Lines changed: 3 additions & 61 deletions
@@ -1,14 +1,9 @@
 import { getCenterPoint } from '../commons/getCenterPoint';
 import { FaceDetection } from '../FaceDetection';
 import { FaceLandmarks } from '../FaceLandmarks';
-import { IPoint, Point } from '../Point';
+import { Point } from '../Point';
 import { Rect } from '../Rect';
 
-// face alignment constants
-const relX = 0.5
-const relY = 0.43
-const relScale = 0.45
-
 export class FaceLandmarks68 extends FaceLandmarks {
   public getJawOutline(): Point[] {
     return this._faceLandmarks.slice(0, 17)
@@ -38,64 +33,11 @@ export class FaceLandmarks68 extends FaceLandmarks {
     return this._faceLandmarks.slice(48, 68)
   }
 
-  public forSize(width: number, height: number): FaceLandmarks68 {
-    return new FaceLandmarks68(
-      this.getRelativePositions(),
-      { width, height }
-    )
-  }
-
-  public shift(x: number, y: number): FaceLandmarks68 {
-    return new FaceLandmarks68(
-      this.getRelativePositions(),
-      { width: this._imageWidth, height: this._imageHeight },
-      new Point(x, y)
-    )
-  }
-
-  public shiftByPoint(pt: IPoint): FaceLandmarks68 {
-    return this.shift(pt.x, pt.y)
-  }
-
-  /**
-   * Aligns the face landmarks after face detection from the relative positions of the faces
-   * bounding box, or it's current shift. This function should be used to align the face images
-   * after face detection has been performed, before they are passed to the face recognition net.
-   * This will make the computed face descriptor more accurate.
-   *
-   * @param detection (optional) The bounding box of the face or the face detection result. If
-   * no argument was passed the position of the face landmarks are assumed to be relative to
-   * it's current shift.
-   * @returns The bounding box of the aligned face.
-   */
-  public align(
-    detection?: FaceDetection | Rect
-  ): Rect {
-    if (detection) {
-      const box = detection instanceof FaceDetection
-        ? detection.getBox().floor()
-        : detection
-
-      return this.shift(box.x, box.y).align()
-    }
-
-    const centers = [
+  protected getRefPointsForAlignment(): Point[] {
+    return [
       this.getLeftEye(),
       this.getRightEye(),
       this.getMouth()
     ].map(getCenterPoint)
-
-    const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
-    const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
-    const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
-
-    const size = Math.floor(eyeToMouthDist / relScale)
-
-    const refPoint = getCenterPoint(centers)
-    // TODO: pad in case rectangle is out of image bounds
-    const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
-    const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
-
-    return new Rect(x, y, size, size)
   }
 }

src/globalApi.ts

Lines changed: 35 additions & 7 deletions
@@ -1,16 +1,17 @@
 import * as tf from '@tensorflow/tfjs-core';
 
-import { allFacesFactory } from './allFacesFactory';
+import { allFacesFactory, allFacesMtcnnFactory } from './allFacesFactory';
+import { extractFaceTensors } from './extractFaceTensors';
 import { FaceDetection } from './FaceDetection';
 import { FaceDetectionNet } from './faceDetectionNet/FaceDetectionNet';
 import { FaceLandmarkNet } from './faceLandmarkNet/FaceLandmarkNet';
 import { FaceLandmarks68 } from './faceLandmarkNet/FaceLandmarks68';
 import { FaceRecognitionNet } from './faceRecognitionNet/FaceRecognitionNet';
 import { FullFaceDescription } from './FullFaceDescription';
-import { getDefaultMtcnnForwardParams } from './mtcnn/getDefaultMtcnnForwardParams';
 import { Mtcnn } from './mtcnn/Mtcnn';
 import { MtcnnForwardParams, MtcnnResult } from './mtcnn/types';
 import { NetInput } from './NetInput';
+import { Rect } from './Rect';
 import { TNetInput } from './types';
 
 export const detectionNet = new FaceDetectionNet()
@@ -22,7 +23,7 @@ export const recognitionNet = new FaceRecognitionNet()
 export const nets = {
   ssdMobilenet: detectionNet,
   faceLandmark68Net: landmarkNet,
-  faceNet: recognitionNet,
+  faceRecognitionNet: recognitionNet,
   mtcnn: new Mtcnn()
 }
 
@@ -35,7 +36,7 @@ export function loadFaceLandmarkModel(url: string) {
 }
 
 export function loadFaceRecognitionModel(url: string) {
-  return nets.faceNet.load(url)
+  return nets.faceRecognitionNet.load(url)
 }
 
 export function loadMtcnnModel(url: string) {
@@ -68,7 +69,7 @@ export function detectLandmarks(
 export function computeFaceDescriptor(
   input: TNetInput
 ): Promise<Float32Array | Float32Array[]> {
-  return nets.faceNet.computeFaceDescriptor(input)
+  return nets.faceRecognitionNet.computeFaceDescriptor(input)
 }
 
 export function mtcnn(
@@ -85,5 +86,32 @@ export const allFaces: (
 ) => Promise<FullFaceDescription[]> = allFacesFactory(
   detectionNet,
   landmarkNet,
-  recognitionNet
-)
+  computeDescriptorsFactory(nets.faceRecognitionNet)
+)
+
+export const allFacesMtcnn: (
+  input: tf.Tensor | NetInput | TNetInput,
+  mtcnnForwardParams: MtcnnForwardParams,
+  useBatchProcessing?: boolean
+) => Promise<FullFaceDescription[]> = allFacesMtcnnFactory(
+  nets.mtcnn,
+  computeDescriptorsFactory(nets.faceRecognitionNet)
+)
+
+function computeDescriptorsFactory(
+  recognitionNet: FaceRecognitionNet
+) {
+  return async function(input: TNetInput, alignedFaceBoxes: Rect[], useBatchProcessing: boolean) {
+    const alignedFaceTensors = await extractFaceTensors(input, alignedFaceBoxes)
+
+    const descriptors = useBatchProcessing
+      ? await recognitionNet.computeFaceDescriptor(alignedFaceTensors) as Float32Array[]
+      : await Promise.all(alignedFaceTensors.map(
+          faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
+        )) as Float32Array[]
+
+    alignedFaceTensors.forEach(t => t.dispose())
+
+    return descriptors
+  }
+}
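
Putting the new global API together, here is a rough usage sketch of allFacesMtcnn. The package import, the model URL, the input element, the minFaceSize parameter and the descriptor getter on FullFaceDescription are assumptions for illustration and are not shown in this diff:

  import * as faceapi from 'face-api.js'

  async function run(input: HTMLImageElement) {
    await faceapi.loadMtcnnModel('/models')
    await faceapi.loadFaceRecognitionModel('/models')

    // detect faces and 5 point landmarks with MTCNN, align them, then compute descriptors
    const fullFaceDescriptions = await faceapi.allFacesMtcnn(input, { minFaceSize: 50 })

    fullFaceDescriptions.forEach(({ detection, landmarks, descriptor }) => {
      console.log(detection.getBox(), landmarks.getPositions(), descriptor)
    })
  }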

src/mtcnn/FaceLandmarks5.ts

Lines changed: 9 additions & 18 deletions
@@ -1,24 +1,15 @@
+import { getCenterPoint } from '../commons/getCenterPoint';
 import { FaceLandmarks } from '../FaceLandmarks';
-import { IPoint, Point } from '../Point';
+import { Point } from '../Point';
 
 export class FaceLandmarks5 extends FaceLandmarks {
 
-  public forSize(width: number, height: number): FaceLandmarks5 {
-    return new FaceLandmarks5(
-      this.getRelativePositions(),
-      { width, height }
-    )
-  }
-
-  public shift(x: number, y: number): FaceLandmarks5 {
-    return new FaceLandmarks5(
-      this.getRelativePositions(),
-      { width: this._imageWidth, height: this._imageHeight },
-      new Point(x, y)
-    )
-  }
-
-  public shiftByPoint(pt: IPoint): FaceLandmarks5 {
-    return this.shift(pt.x, pt.y)
+  protected getRefPointsForAlignment(): Point[] {
+    const pts = this.getPositions()
+    return [
+      pts[0],
+      pts[1],
+      getCenterPoint([pts[3], pts[4]])
+    ]
   }
 }
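
getRefPointsForAlignment indexes into the raw 5 point landmarks, which assumes the usual MTCNN ordering (not spelled out in this diff): the two eyes first, then the nose, then the two mouth corners. A minimal sketch of that mapping, assuming a FaceLandmarks5 instance faceLandmarks5 and the getCenterPoint helper from src/commons:

  // assumed ordering: [leftEye, rightEye, nose, leftMouthCorner, rightMouthCorner]
  const [leftEye, rightEye, , leftMouth, rightMouth] = faceLandmarks5.getPositions()
  const refPoints = [leftEye, rightEye, getCenterPoint([leftMouth, rightMouth])]
  // these are the same reference points getRefPointsForAlignment() returns, so
  // FaceLandmarks.align() applies the same eyes-and-mouth geometry to the 5 point
  // landmarks as it does to the 68 point landmarks of FaceLandmarks68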
