Skip to content

Commit 93959e1

Browse files
implemented face alignment from landmarks
1 parent 5f46b72 commit 93959e1

File tree

7 files changed

+129
-26
lines changed

7 files changed

+129
-26
lines changed

src/Point.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,16 @@ export class Point implements IPoint {
2727
public div(pt: IPoint): Point {
2828
return new Point(this.x / pt.x, this.y / pt.y)
2929
}
30+
31+
public abs(): Point {
32+
return new Point(Math.abs(this.x), Math.abs(this.y))
33+
}
34+
35+
public magnitude(): number {
36+
return Math.sqrt(Math.pow(this.x, 2) + Math.pow(this.y, 2))
37+
}
38+
39+
public floor(): Point {
40+
return new Point(Math.floor(this.x), Math.floor(this.y))
41+
}
3042
}

src/commons/getCenterPoint.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import { Point } from '../Point';
2+
3+
export function getCenterPoint(pts: Point[]): Point {
4+
return pts.reduce((sum, pt) => sum.add(pt), new Point(0, 0))
5+
.div(new Point(pts.length, pts.length))
6+
}

src/extractFaceTensors.ts

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,37 @@ import * as tf from '@tensorflow/tfjs-core';
33
import { FaceDetection } from './faceDetectionNet/FaceDetection';
44
import { getImageTensor } from './getImageTensor';
55
import { NetInput } from './NetInput';
6+
import { Rect } from './Rect';
67
import { TNetInput } from './types';
78

89
/**
910
* Extracts the tensors of the image regions containing the detected faces.
10-
* Returned tensors have to be disposed manually once you don't need them anymore!
11-
* Useful if you want to compute the face descriptors for the face
12-
* images. Using this method is faster then extracting a canvas for each face and
11+
* Useful if you want to compute the face descriptors for the face images.
12+
* Using this method is faster then extracting a canvas for each face and
1313
* converting them to tensors individually.
1414
*
1515
* @param input The image that face detection has been performed on.
16-
* @param detections The face detection results for that image.
16+
* @param detections The face detection results or face bounding boxes for that image.
1717
* @returns Tensors of the corresponding image region for each detected face.
1818
*/
1919
export function extractFaceTensors(
2020
image: tf.Tensor | NetInput | TNetInput,
21-
detections: FaceDetection[]
21+
detections: Array<FaceDetection|Rect>
2222
): tf.Tensor4D[] {
2323
return tf.tidy(() => {
2424
const imgTensor = getImageTensor(image)
2525

2626
// TODO handle batches
2727
const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape
2828

29-
const faceTensors = detections.map(det => {
30-
const { x, y, width, height } = det.forSize(imgWidth, imgHeight).getBox().floor()
31-
return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
32-
})
29+
const boxes = detections.map(
30+
det => det instanceof FaceDetection
31+
? det.forSize(imgWidth, imgHeight).getBox().floor()
32+
: det
33+
)
34+
const faceTensors = boxes.map(({ x, y, width, height }) =>
35+
tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
36+
)
3337

3438
return faceTensors
3539
})

src/extractFaces.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
import { FaceDetection } from './faceDetectionNet/FaceDetection';
2+
import { Rect } from './Rect';
23
import { createCanvas, getContext2dOrThrow } from './utils';
34

45
/**
56
* Extracts the image regions containing the detected faces.
67
*
78
* @param input The image that face detection has been performed on.
8-
* @param detections The face detection results for that image.
9+
* @param detections The face detection results or face bounding boxes for that image.
910
* @returns The Canvases of the corresponding image region for each detected face.
1011
*/
1112
export function extractFaces(
1213
image: HTMLCanvasElement,
13-
detections: FaceDetection[]
14+
detections: Array<FaceDetection|Rect>
1415
): HTMLCanvasElement[] {
1516
const ctx = getContext2dOrThrow(image)
1617

17-
return detections.map(det => {
18-
const { x, y, width, height } = det.forSize(image.width, image.height).getBox().floor()
18+
const boxes = detections.map(
19+
det => det instanceof FaceDetection
20+
? det.forSize(image.width, image.height).getBox().floor()
21+
: det
22+
)
23+
return boxes.map(({ x, y, width, height }) => {
1924
const faceImg = createCanvas({ width, height })
2025
getContext2dOrThrow(faceImg)
2126
.putImageData(ctx.getImageData(x, y, width, height), 0, 0)

src/faceLandmarkNet/FaceLandmarks.ts

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
import { Point, IPoint } from '../Point';
1+
import { getCenterPoint } from '../commons/getCenterPoint';
2+
import { FaceDetection } from '../faceDetectionNet/FaceDetection';
3+
import { Point } from '../Point';
4+
import { Rect } from '../Rect';
25
import { Dimensions } from '../types';
36

7+
// face alignment constants
8+
const relX = 0.5
9+
const relY = 0.43
10+
const relScale = 0.45
11+
412
export class FaceLandmarks {
5-
private _faceLandmarks: Point[]
613
private _imageWidth: number
714
private _imageHeight: number
815
private _shift: Point
16+
private _faceLandmarks: Point[]
917

1018
constructor(
1119
relativeFaceLandmarkPositions: Point[],
@@ -21,41 +29,53 @@ export class FaceLandmarks {
2129
)
2230
}
2331

24-
public getPositions() {
32+
public getShift(): Point {
33+
return new Point(this._shift.x, this._shift.y)
34+
}
35+
36+
public getImageWidth(): number {
37+
return this._imageWidth
38+
}
39+
40+
public getImageHeight(): number {
41+
return this._imageHeight
42+
}
43+
44+
public getPositions(): Point[] {
2545
return this._faceLandmarks
2646
}
2747

28-
public getRelativePositions() {
48+
public getRelativePositions(): Point[] {
2949
return this._faceLandmarks.map(
3050
pt => pt.sub(this._shift).div(new Point(this._imageWidth, this._imageHeight))
3151
)
3252
}
3353

34-
public getJawOutline() {
54+
public getJawOutline(): Point[] {
3555
return this._faceLandmarks.slice(0, 17)
3656
}
3757

38-
public getLeftEyeBrow() {
58+
public getLeftEyeBrow(): Point[] {
3959
return this._faceLandmarks.slice(17, 22)
4060
}
4161

42-
public getRightEyeBrow() {
62+
public getRightEyeBrow(): Point[] {
4363
return this._faceLandmarks.slice(22, 27)
4464
}
4565

46-
public getNose() {
66+
public getNose(): Point[] {
4767
return this._faceLandmarks.slice(27, 36)
4868
}
4969

50-
public getLeftEye() {
70+
public getLeftEye(): Point[] {
5171
return this._faceLandmarks.slice(36, 42)
5272
}
5373

54-
public getRightEye() {
74+
public getRightEye(): Point[] {
5575
return this._faceLandmarks.slice(42, 48)
5676
}
5777

58-
public getMouth() {
78+
public getMouth(): Point[] {
5979
return this._faceLandmarks.slice(48, 68)
6080
}
6181

@@ -73,4 +93,46 @@ export class FaceLandmarks {
7393
new Point(x, y)
7494
)
7595
}
96+
97+
/**
98+
* Aligns the face landmarks after face detection from the relative positions of the faces
99+
* bounding box, or it's current shift. This function should be used to align the face images
100+
* after face detection has been performed, before they are passed to the face recognition net.
101+
* This will make the computed face descriptor more accurate.
102+
*
103+
* @param detection (optional) The bounding box of the face or the face detection result. If
104+
* no argument was passed the position of the face landmarks are assumed to be relative to
105+
* it's current shift.
106+
* @returns The bounding box of the aligned face.
107+
*/
108+
public align(
109+
detection?: Rect
110+
): Rect {
111+
if (detection) {
112+
const box = detection instanceof FaceDetection
113+
? detection.getBox().floor()
114+
: detection
115+
116+
return this.shift(box.x, box.y).align()
117+
}
118+
119+
const centers = [
120+
this.getLeftEye(),
121+
this.getRightEye(),
122+
this.getMouth()
123+
].map(getCenterPoint)
124+
125+
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
126+
const distToMouth = (pt: Point) => mouthCenter.sub(pt).magnitude()
127+
const eyeToMouthDist = (distToMouth(leftEyeCenter) + distToMouth(rightEyeCenter)) / 2
128+
129+
const size = Math.floor(eyeToMouthDist / relScale)
130+
131+
const refPoint = getCenterPoint(centers)
132+
// TODO: pad in case rectangle is out of image bounds
133+
const x = Math.floor(Math.max(0, refPoint.x - (relX * size)))
134+
const y = Math.floor(Math.max(0, refPoint.y - (relY * size)))
135+
136+
return new Rect(x, y, size, size)
137+
}
76138
}

src/padToSquare.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import * as tf from '@tensorflow/tfjs-core';
44
* Pads the smaller dimension of an image tensor with zeros, such that width === height.
55
*
66
* @param imgTensor The image tensor.
7-
* @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image
7+
* @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image.
88
* @returns The padded tensor with width === height.
99
*/
1010
export function padToSquare(

src/utils.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
13
import { FaceDetection } from './faceDetectionNet/FaceDetection';
24
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
3-
import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
45
import { Point } from './Point';
6+
import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
57

68
export function isFloat(num: number) {
79
return num % 1 !== 0
@@ -68,6 +70,18 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
6870
})
6971
}
7072

73+
export async function imageTensorToCanvas(
74+
imgTensor: tf.Tensor4D,
75+
canvas?: HTMLCanvasElement
76+
): Promise<HTMLCanvasElement> {
77+
const targetCanvas = canvas || document.createElement('canvas')
78+
79+
const [_, height, width, numChannels] = imgTensor.shape
80+
await tf.toPixels(imgTensor.as3D(height, width, numChannels).toInt(), targetCanvas)
81+
82+
return targetCanvas
83+
}
84+
7185
export function getDefaultDrawOptions(): DrawOptions {
7286
return {
7387
color: 'blue',

0 commit comments

Comments
 (0)