Skip to content

Commit 4205fc2

Browse files
allow face alignment before classification
1 parent 256ee65 commit 4205fc2

File tree

7 files changed

+89
-33
lines changed

7 files changed

+89
-33
lines changed

src/classes/FaceLandmarks.ts

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { Dimensions, getCenterPoint, IDimensions, Point, Rect } from 'tfjs-image-recognition-base';
1+
import { Box, Dimensions, getCenterPoint, IBoundingBox, IDimensions, IRect, Point, Rect } from 'tfjs-image-recognition-base';
22

3+
import { minBbox } from '../minBbox';
34
import { FaceDetection } from './FaceDetection';
45

56
// face alignment constants
@@ -71,16 +72,28 @@ export class FaceLandmarks implements IFaceLandmarks {
7172
* @returns The bounding box of the aligned face.
7273
*/
7374
public align(
  detection?: FaceDetection | IRect | IBoundingBox | null,
  options: { useDlibAlignment?: boolean, minBoxPadding?: number } = { }
): Box {
  if (detection) {
    // Normalize the input to a Box; a FaceDetection contributes its floored box.
    const box = detection instanceof FaceDetection
      ? detection.box.floor()
      : new Box(detection)

    // Shift the landmarks by the box origin, then align the shifted landmarks
    // with the same options (no further shifting, hence the null detection).
    return this.shiftBy(box.x, box.y).align(null, options)
  }

  // Destructuring defaults also apply when a caller passes a key that is
  // explicitly set to undefined, which Object.assign would have copied over
  // the default value (resulting in e.g. a NaN padding).
  const { useDlibAlignment = false, minBoxPadding = 0.2 } = options

  if (useDlibAlignment) {
    return this.alignDlib()
  }

  return this.alignMinBbox(minBoxPadding)
}
94+
95+
private alignDlib(): Box {
96+
8497
const centers = this.getRefPointsForAlignment()
8598

8699
const [leftEyeCenter, rightEyeCenter, mouthCenter] = centers
@@ -97,6 +110,11 @@ export class FaceLandmarks implements IFaceLandmarks {
97110
return new Rect(x, y, Math.min(size, this.imageWidth + x), Math.min(size, this.imageHeight + y))
98111
}
99112

113+
/**
 * Aligns the face using the minimum bounding box of `this.positions`,
 * padded relative to the size of that box.
 *
 * @param padding Factor multiplied with the box width and height to obtain
 * the horizontal and vertical padding handed to `pad`.
 * @returns The padded box.
 */
private alignMinBbox(padding: number): Box {
  const landmarksBox = minBbox(this.positions)
  const padX = landmarksBox.width * padding
  const padY = landmarksBox.height * padding

  return landmarksBox.pad(padX, padY)
}
117+
100118
/**
 * Hook for retrieving the reference points used for alignment.
 * This base implementation always throws.
 *
 * @throws Error stating that the base class does not implement this method.
 */
protected getRefPointsForAlignment(): Point[] {
  const message = 'getRefPointsForAlignment not implemented by base class'
  throw new Error(message)
}

src/globalApi/ComputeFaceDescriptorsTasks.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@ export class ComputeAllFaceDescriptorsTask<
2525

2626
const parentResults = await this.parentTask
2727

28-
const alignedRects = parentResults.map(({ alignedRect }) => alignedRect)
29-
const alignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
30-
? await extractFaceTensors(this.input, alignedRects)
31-
: await extractFaces(this.input, alignedRects)
28+
const dlibAlignedRects = parentResults.map(({ landmarks }) => landmarks.align(null, { useDlibAlignment: true }))
29+
const dlibAlignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
30+
? await extractFaceTensors(this.input, dlibAlignedRects)
31+
: await extractFaces(this.input, dlibAlignedRects)
3232

3333
const results = await Promise.all(parentResults.map(async (parentResult, i) => {
34-
const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(alignedFaces[i]) as Float32Array
34+
const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(dlibAlignedFaces[i]) as Float32Array
3535
return extendWithFaceDescriptor<TSource>(parentResult, descriptor)
3636
}))
3737

38-
alignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
38+
dlibAlignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())
3939

4040
return results
4141
}
@@ -52,10 +52,10 @@ export class ComputeSingleFaceDescriptorTask<
5252
return
5353
}
5454

55-
const { alignedRect } = parentResult
55+
const dlibAlignedRect = parentResult.landmarks.align(null, { useDlibAlignment: true })
5656
const alignedFaces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
57-
? await extractFaceTensors(this.input, [alignedRect])
58-
: await extractFaces(this.input, [alignedRect])
57+
? await extractFaceTensors(this.input, [dlibAlignedRect])
58+
: await extractFaces(this.input, [dlibAlignedRect])
5959
const descriptor = await nets.faceRecognitionNet.computeFaceDescriptor(alignedFaces[0]) as Float32Array
6060

6161
alignedFaces.forEach(f => f instanceof tf.Tensor && f.dispose())

src/globalApi/DetectFaceLandmarksTasks.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ import { extendWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFac
1010
import { ComposableTask } from './ComposableTask';
1111
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
1212
import { nets } from './nets';
13-
import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionTask } from './PredictFaceExpressionsTask';
13+
import {
14+
PredictAllFaceExpressionsWithFaceAlignmentTask,
15+
PredictSingleFaceExpressionsWithFaceAlignmentTask,
16+
} from './PredictFaceExpressionsTask';
1417

1518
export class DetectFaceLandmarksTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
1619
constructor(
@@ -52,6 +55,10 @@ export class DetectAllFaceLandmarksTask<
5255
)
5356
}
5457

58+
/**
 * Creates a face expression prediction task that uses this landmark task
 * as its parent and operates on the same input.
 */
withFaceExpressions(): PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
  const expressionsTask = new PredictAllFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(
    this,
    this.input
  )

  return expressionsTask
}
61+
5562
/**
 * Creates a face descriptor computation task that uses this landmark task
 * as its parent and operates on the same input.
 */
withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
  const descriptorsTask = new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(
    this,
    this.input
  )

  return descriptorsTask
}
@@ -80,6 +87,10 @@ export class DetectSingleFaceLandmarksTask<
8087
return extendWithFaceLandmarks<TSource>(parentResult, landmarks)
8188
}
8289

90+
/**
 * Creates a single face expression prediction task that uses this landmark
 * task as its parent and operates on the same input.
 */
withFaceExpressions(): PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>> {
  const expressionsTask = new PredictSingleFaceExpressionsWithFaceAlignmentTask<WithFaceLandmarks<TSource>>(
    this,
    this.input
  )

  return expressionsTask
}
93+
8394
/**
 * Creates a single face descriptor computation task that uses this landmark
 * task as its parent and operates on the same input.
 */
withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
  const descriptorTask = new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(
    this,
    this.input
  )

  return descriptorTask
}

src/globalApi/DetectFacesTasks.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { TinyFaceDetectorOptions } from '../tinyFaceDetector/TinyFaceDetectorOpt
88
import { ComposableTask } from './ComposableTask';
99
import { DetectAllFaceLandmarksTask, DetectSingleFaceLandmarksTask } from './DetectFaceLandmarksTasks';
1010
import { nets } from './nets';
11-
import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionTask } from './PredictFaceExpressionsTask';
11+
import { PredictAllFaceExpressionsTask, PredictSingleFaceExpressionsTask } from './PredictFaceExpressionsTask';
1212
import { FaceDetectionOptions } from './types';
1313

1414
export class DetectFacesTaskBase<TReturn> extends ComposableTask<TReturn> {
@@ -101,8 +101,8 @@ export class DetectSingleFaceTask extends DetectFacesTaskBase<FaceDetection | un
101101
)
102102
}
103103

104-
withFaceExpressions(): PredictSingleFaceExpressionTask<WithFaceDetection<{}>> {
105-
return new PredictSingleFaceExpressionTask<WithFaceDetection<{}>>(
104+
withFaceExpressions(): PredictSingleFaceExpressionsTask<WithFaceDetection<{}>> {
105+
return new PredictSingleFaceExpressionsTask<WithFaceDetection<{}>>(
106106
this.runAndExtendWithFaceDetection(),
107107
this.input
108108
)

src/globalApi/PredictFaceExpressionsTask.ts

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import { extractFaces, extractFaceTensors } from '../dom';
55
import { FaceExpressions } from '../faceExpressionNet/FaceExpressions';
66
import { WithFaceDetection } from '../factories/WithFaceDetection';
77
import { extendWithFaceExpressions, WithFaceExpressions } from '../factories/WithFaceExpressions';
8+
import { isWithFaceLandmarks, WithFaceLandmarks } from '../factories/WithFaceLandmarks';
89
import { ComposableTask } from './ComposableTask';
9-
import { DetectAllFaceLandmarksTask, DetectSingleFaceLandmarksTask } from './DetectFaceLandmarksTasks';
10+
import { ComputeAllFaceDescriptorsTask, ComputeSingleFaceDescriptorTask } from './ComputeFaceDescriptorsTasks';
1011
import { nets } from './nets';
1112

1213
export class PredictFaceExpressionsTaskBase<TReturn, TParentReturn> extends ComposableTask<TReturn> {
@@ -26,10 +27,12 @@ export class PredictAllFaceExpressionsTask<
2627

2728
const parentResults = await this.parentTask
2829

29-
const detections = parentResults.map(parentResult => parentResult.detection)
30+
const faceBoxes = parentResults.map(
31+
parentResult => isWithFaceLandmarks(parentResult) ? parentResult.alignedRect : parentResult.detection
32+
)
3033
const faces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
31-
? await extractFaceTensors(this.input, detections)
32-
: await extractFaces(this.input, detections)
34+
? await extractFaceTensors(this.input, faceBoxes)
35+
: await extractFaces(this.input, faceBoxes)
3336

3437
const faceExpressionsByFace = await Promise.all(faces.map(
3538
face => nets.faceExpressionNet.predictExpressions(face)
@@ -41,13 +44,9 @@ export class PredictAllFaceExpressionsTask<
4144
(parentResult, i) => extendWithFaceExpressions<TSource>(parentResult, faceExpressionsByFace[i])
4245
)
4346
}
44-
45-
withFaceLandmarks(): DetectAllFaceLandmarksTask<WithFaceExpressions<TSource>> {
46-
return new DetectAllFaceLandmarksTask(this, this.input, false)
47-
}
4847
}
4948

50-
export class PredictSingleFaceExpressionTask<
49+
export class PredictSingleFaceExpressionsTask<
5150
TSource extends WithFaceDetection<{}>
5251
> extends PredictFaceExpressionsTaskBase<WithFaceExpressions<TSource> | undefined, TSource | undefined> {
5352

@@ -58,19 +57,33 @@ export class PredictSingleFaceExpressionTask<
5857
return
5958
}
6059

61-
const { detection } = parentResult
60+
const faceBox = isWithFaceLandmarks(parentResult) ? parentResult.alignedRect : parentResult.detection
6261
const faces: Array<HTMLCanvasElement | tf.Tensor3D> = this.input instanceof tf.Tensor
63-
? await extractFaceTensors(this.input, [detection])
64-
: await extractFaces(this.input, [detection])
62+
? await extractFaceTensors(this.input, [faceBox])
63+
: await extractFaces(this.input, [faceBox])
6564

6665
const faceExpressions = await nets.faceExpressionNet.predictExpressions(faces[0]) as FaceExpressions
6766

6867
faces.forEach(f => f instanceof tf.Tensor && f.dispose())
6968

7069
return extendWithFaceExpressions(parentResult, faceExpressions)
7170
}
71+
}
72+
73+
/**
 * Expression prediction task for sources that already carry face landmarks.
 * Adds the option to continue the chain with face descriptor computation.
 */
export class PredictAllFaceExpressionsWithFaceAlignmentTask<
  TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
> extends PredictAllFaceExpressionsTask<TSource> {

  /**
   * Creates a face descriptor computation task that uses this task as its
   * parent and operates on the same input.
   */
  // NOTE(review): the declared result type does not mention WithFaceExpressions,
  // although the parent is an expressions task - confirm this is intended.
  withFaceDescriptors(): ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>> {
    const descriptorsTask = new ComputeAllFaceDescriptorsTask<WithFaceLandmarks<TSource>>(
      this,
      this.input
    )

    return descriptorsTask
  }
}
81+
82+
/**
 * Single face expression prediction task for sources that already carry face
 * landmarks. Adds the option to continue the chain with face descriptor
 * computation.
 */
export class PredictSingleFaceExpressionsWithFaceAlignmentTask<
  TSource extends WithFaceLandmarks<WithFaceDetection<{}>>
> extends PredictSingleFaceExpressionsTask<TSource> {

  /**
   * Creates a single face descriptor computation task that uses this task as
   * its parent and operates on the same input.
   */
  // NOTE(review): the declared result type does not mention WithFaceExpressions,
  // although the parent is an expressions task - confirm this is intended.
  withFaceDescriptor(): ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>> {
    const descriptorTask = new ComputeSingleFaceDescriptorTask<WithFaceLandmarks<TSource>>(
      this,
      this.input
    )

    return descriptorTask
  }
}

src/minBbox.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { BoundingBox, IPoint } from 'tfjs-image-recognition-base';
2+
3+
/**
 * Computes the minimum axis-aligned bounding box enclosing the given points.
 *
 * The maxima are seeded with -Infinity (not 0), so point sets that lie
 * entirely at negative coordinates still yield their true right / bottom edge.
 * Coordinates that are NaN never win a comparison and are therefore ignored.
 *
 * @param pts The points to enclose.
 * @returns A BoundingBox constructed from (minX, minY, maxX, maxY). For an
 * empty array the extremes remain at +/-Infinity and are passed on unchanged.
 */
export function minBbox(pts: IPoint[]): BoundingBox {
  let minX = Infinity
  let minY = Infinity
  let maxX = -Infinity
  let maxY = -Infinity

  // Single pass over the points instead of four separate reductions.
  for (const { x, y } of pts) {
    if (x < minX) minX = x
    if (y < minY) minY = y
    if (maxX < x) maxX = x
    if (maxY < y) maxY = y
  }

  return new BoundingBox(minX, minY, maxX, maxY)
}

src/xception/TinyXception.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
} from 'tfjs-image-recognition-base';
1111

1212
import { depthwiseSeparableConv } from '../common/depthwiseSeparableConv';
13+
import { bgrToRgbTensor } from '../mtcnn/bgrToRgbTensor';
1314
import { extractParams } from './extractParams';
1415
import { extractParamsFromWeigthMap } from './extractParamsFromWeigthMap';
1516
import { MainBlockParams, ReductionBlockParams, TinyXceptionParams } from './types';
@@ -54,8 +55,9 @@ export class TinyXception extends NeuralNetwork<TinyXceptionParams> {
5455

5556
return tf.tidy(() => {
5657
const batchTensor = input.toBatchTensor(112, true)
58+
const batchTensorRgb = bgrToRgbTensor(batchTensor)
5759
const meanRgb = [122.782, 117.001, 104.298]
58-
const normalized = normalize(batchTensor, meanRgb).div(tf.scalar(255)) as tf.Tensor4D
60+
const normalized = normalize(batchTensorRgb, meanRgb).div(tf.scalar(256)) as tf.Tensor4D
5961

6062
let out = tf.relu(conv(normalized, params.entry_flow.conv_in, [2, 2]))
6163
out = reductionBlock(out, params.entry_flow.reduction_block_0, false)

0 commit comments

Comments
 (0)