Skip to content

Commit 174e0c2

Browse files
FaceMatcher
1 parent 397c05a commit 174e0c2

File tree

11 files changed

+143
-59
lines changed

11 files changed

+143
-59
lines changed

examples/public/js/bbt.js

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,40 +30,24 @@ function renderFaceImageSelectList(selectListId, onChange, initialValue) {
3030
}
3131

3232
// fetch first image of each class and compute their descriptors
33-
async function initBbtFaceDescriptors(net, numImagesForTraining = 1) {
33+
async function createBbtFaceMatcher(numImagesForTraining = 1) {
3434
const maxAvailableImagesPerClass = 5
3535
numImagesForTraining = Math.min(numImagesForTraining, maxAvailableImagesPerClass)
36-
return Promise.all(classes.map(
36+
37+
const labeledFaceDescriptors = await Promise.all(classes.map(
3738
async className => {
3839
const descriptors = []
3940
for (let i = 1; i < (numImagesForTraining + 1); i++) {
4041
const img = await faceapi.fetchImage(getFaceImageUri(className, i))
4142
descriptors.push(await faceapi.computeFaceDescriptor(img))
4243
}
43-
return {
44-
descriptors,
45-
className
46-
}
47-
}
48-
))
49-
}
5044

51-
function getBestMatch(descriptorsByClass, queryDescriptor) {
52-
function computeMeanDistance(descriptorsOfClass) {
53-
return faceapi.round(
54-
descriptorsOfClass
55-
.map(d => faceapi.euclideanDistance(d, queryDescriptor))
56-
.reduce((d1, d2) => d1 + d2, 0)
57-
/ (descriptorsOfClass.length || 1)
45+
return new faceapi.LabeledFaceDescriptors(
46+
className,
47+
descriptors
5848
)
59-
}
60-
return descriptorsByClass
61-
.map(
62-
({ descriptors, className }) => ({
63-
distance: computeMeanDistance(descriptors),
64-
className
65-
})
66-
)
67-
.reduce((best, curr) => best.distance < curr.distance ? best : curr)
68-
}
49+
}
50+
))
6951

52+
return new faceapi.FaceMatcher(labeledFaceDescriptors)
53+
}

examples/views/batchFaceRecognition.html

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@
4848

4949
<script>
5050
let images = []
51-
let referenceDescriptorsByClass = []
52-
let descriptorsByFace = []
51+
let faceMatcher = null
5352
let numImages = 16
5453
let maxDistance = 0.6
5554

@@ -68,15 +67,12 @@
6867
const canvas = faceapi.createCanvasFromMedia(img)
6968
$('#faceContainer').append(canvas)
7069

71-
const bestMatch = getBestMatch(referenceDescriptorsByClass, descriptor)
72-
const text = `${bestMatch.distance < maxDistance ? bestMatch.className : 'unkown'} (${bestMatch.distance})`
73-
7470
const x = 20, y = canvas.height - 20
7571
faceapi.drawText(
7672
canvas.getContext('2d'),
7773
x,
7874
y,
79-
text,
75+
faceMatcher.findBestMatch(descriptor).toString(),
8076
Object.assign(faceapi.getDefaultDrawOptions(), { color: 'red', fontSize: 16 })
8177
)
8278
}
@@ -106,7 +102,7 @@
106102

107103
async function run() {
108104
await faceapi.loadFaceRecognitionModel('/')
109-
referenceDescriptorsByClass = await initBbtFaceDescriptors(faceapi.recognitionNet, 1)
105+
faceMatcher = await createBbtFaceMatcher(1)
110106
$('#loader').hide()
111107

112108
const imgUris = classes

examples/views/bbtFaceRecognition.html

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,10 @@
6363
</div>
6464

6565
<script>
66-
// for 150 x 150 sized face images 0.6 is a good threshold to
67-
// judge whether two face descriptors are similar or not
68-
const threshold = 0.6
6966
let interval = 2000
7067

7168
let isStop = false
72-
let referenceDescriptorsByClass = []
69+
let faceMatcher = null
7370
let currImageIdx = 2, currClassIdx = 0
7471
let to = null
7572

@@ -116,8 +113,8 @@
116113
const descriptor = await faceapi.computeFaceDescriptor(input)
117114
displayTimeStats(Date.now() - ts)
118115

119-
const bestMatch = getBestMatch(referenceDescriptorsByClass, descriptor)
120-
$('#prediction').val(`${bestMatch.distance < threshold ? bestMatch.className : 'unknown'} (${bestMatch.distance})`)
116+
const bestMatch = faceMatcher.findBestMatch(descriptor)
117+
$('#prediction').val(bestMatch.toString())
121118

122119
currImageIdx = currClassIdx === (classes.length - 1)
123120
? currImageIdx + 1
@@ -138,7 +135,7 @@
138135

139136
setStatusText('computing initial descriptors...')
140137

141-
referenceDescriptorsByClass = await initBbtFaceDescriptors(faceapi.recognitionNet)
138+
faceMatcher = await createBbtFaceMatcher(1)
142139
$('#loader').hide()
143140

144141
runFaceRecognition()

examples/views/faceRecognition.html

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@
139139
</body>
140140

141141
<script>
142-
const maxDescriptorDistance = 0.6
143-
let referenceDescriptorsByClass = []
142+
let faceMatcher = null
144143

145144
async function updateResults() {
146145
if (!isFaceDetectionModelLoaded()) {
@@ -159,20 +158,17 @@
159158
}
160159

161160
function drawFaceRecognitionResults(results) {
162-
const { width, height } = $('#inputImg').get(0)
163161
const canvas = $('#overlay').get(0)
164-
canvas.width = width
165-
canvas.height = height
166-
167162
// resize detection and landmarks in case displayed image is smaller than
168163
// original size
169-
results = results.map(res => res.forSize(width, height))
170-
171-
const boxesWithText = results.map(({ detection, descriptor }) => {
172-
const bestMatch = getBestMatch(referenceDescriptorsByClass, descriptor)
173-
const text = `${bestMatch.distance < maxDescriptorDistance ? bestMatch.className : 'unknown'} (${bestMatch.distance})`
174-
return new faceapi.BoxWithText(detection.box, text)
175-
})
164+
resizedResults = resizeCanvasAndResults($('#inputImg').get(0), canvas, results)
165+
166+
const boxesWithText = resizedResults.map(({ detection, descriptor }) =>
167+
new faceapi.BoxWithText(
168+
detection.box,
169+
faceMatcher.findBestMatch(descriptor).toString()
170+
)
171+
)
176172
faceapi.drawDetection(canvas, boxesWithText)
177173
}
178174

@@ -182,8 +178,8 @@
182178
await faceapi.loadFaceLandmarkModel('/')
183179
await faceapi.loadFaceRecognitionModel('/')
184180

185-
// initialize reference descriptors (1 per bbt character)
186-
referenceDescriptorsByClass = await initBbtFaceDescriptors(faceapi.recognitionNet, 1)
181+
// initialize face matcher with 1 reference descriptor per bbt character
182+
faceMatcher = await createBbtFaceMatcher(1)
187183

188184
// start processing image
189185
updateResults()

src/classes/FaceMatch.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import { round } from 'tfjs-image-recognition-base';
2+
3+
export class FaceMatch {
4+
private _label: string
5+
private _distance: number
6+
7+
constructor(label: string, distance: number) {
8+
this._label = label
9+
this._distance = distance
10+
}
11+
12+
public get label(): string { return this._label }
13+
public get distance(): number { return this._distance }
14+
15+
public toString(withDistance: boolean = true): string {
16+
return `${this.label}${withDistance ? ` (${round(this.distance)})` : ''}`
17+
}
18+
}

src/classes/FullFaceDescription.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { FaceLandmarks } from './FaceLandmarks';
44
import { FaceLandmarks68 } from './FaceLandmarks68';
55

66
export interface IFullFaceDescription<TFaceLandmarks extends FaceLandmarks = FaceLandmarks68>
7-
extends IFaceDetectionWithLandmarks<TFaceLandmarks>{
7+
extends IFaceDetectionWithLandmarks<TFaceLandmarks> {
88

99
detection: FaceDetection,
1010
landmarks: TFaceLandmarks,

src/classes/LabeledFaceDescriptors.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
export class LabeledFaceDescriptors {
2+
private _label: string
3+
private _descriptors: Float32Array[]
4+
5+
constructor(label: string, descriptors: Float32Array[]) {
6+
if (!(typeof label === 'string')) {
7+
throw new Error('LabeledFaceDescriptors - constructor expected label to be a string')
8+
}
9+
10+
if (!Array.isArray(descriptors) || descriptors.some(desc => !(desc instanceof Float32Array))) {
11+
throw new Error('LabeledFaceDescriptors - constructor expected descriptors to be an array of Float32Array')
12+
}
13+
14+
this._label = label
15+
this._descriptors = descriptors
16+
}
17+
18+
public get label(): string { return this._label }
19+
public get descriptors(): Float32Array[] { return this._descriptors }
20+
}

src/classes/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@ export * from './FaceDetection';
22
export * from './FaceLandmarks';
33
export * from './FaceLandmarks5';
44
export * from './FaceLandmarks68';
5-
export * from './FullFaceDescription';
5+
export * from './FaceMatch';
6+
export * from './FullFaceDescription';
7+
export * from './LabeledFaceDescriptors';

src/faceLandmarkNet/FaceLandmark68Net.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ function denseBlock(
3838
export class FaceLandmark68Net extends FaceLandmark68NetBase<NetParams> {
3939

4040
constructor() {
41-
super('FaceLandmark68LargeNet')
41+
super('FaceLandmark68Net')
4242
}
4343

4444
public runNet(input: NetInput): tf.Tensor2D {
4545

4646
const { params } = this
4747

4848
if (!params) {
49-
throw new Error('FaceLandmark68LargeNet - load model before inference')
49+
throw new Error('FaceLandmark68Net - load model before inference')
5050
}
5151

5252
return tf.tidy(() => {

src/globalApi/FaceMatcher.ts

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import { FaceMatch } from '../classes/FaceMatch';
2+
import { FullFaceDescription } from '../classes/FullFaceDescription';
3+
import { LabeledFaceDescriptors } from '../classes/LabeledFaceDescriptors';
4+
import { euclideanDistance } from '../euclideanDistance';
5+
6+
export class FaceMatcher {
7+
8+
private _labeledDescriptors: LabeledFaceDescriptors[]
9+
private _distanceThreshold: number
10+
11+
constructor(
12+
inputs: LabeledFaceDescriptors | FullFaceDescription | Float32Array | Array<LabeledFaceDescriptors | FullFaceDescription | Float32Array>,
13+
distanceThreshold: number = 0.6
14+
) {
15+
16+
this._distanceThreshold = distanceThreshold
17+
18+
const inputArray = Array.isArray(inputs) ? inputs : [inputs]
19+
20+
if (!inputArray.length) {
21+
throw new Error(`FaceRecognizer.constructor - expected atleast one input`)
22+
}
23+
24+
let count = 1
25+
const createUniqueLabel = () => `person ${count++}`
26+
27+
this._labeledDescriptors = inputArray.map((desc) => {
28+
if (desc instanceof LabeledFaceDescriptors) {
29+
return desc
30+
}
31+
32+
if (desc instanceof FullFaceDescription) {
33+
return new LabeledFaceDescriptors(createUniqueLabel(), [desc.descriptor])
34+
}
35+
36+
if (desc instanceof Float32Array) {
37+
return new LabeledFaceDescriptors(createUniqueLabel(), [desc])
38+
}
39+
40+
throw new Error(`FaceRecognizer.constructor - expected inputs to be of type LabeledFaceDescriptors | FullFaceDescription | Float32Array | Array<LabeledFaceDescriptors | FullFaceDescription | Float32Array>`)
41+
})
42+
}
43+
44+
public get labeledDescriptors(): LabeledFaceDescriptors[] { return this._labeledDescriptors }
45+
public get distanceThreshold(): number { return this._distanceThreshold }
46+
47+
public computeMeanDistance(queryDescriptor: Float32Array, descriptors: Float32Array[]): number {
48+
return descriptors
49+
.map(d => euclideanDistance(d, queryDescriptor))
50+
.reduce((d1, d2) => d1 + d2, 0)
51+
/ (descriptors.length || 1)
52+
}
53+
54+
public matchDescriptor(queryDescriptor: Float32Array): FaceMatch {
55+
return this.labeledDescriptors
56+
.map(({ descriptors, label }) => new FaceMatch(
57+
label,
58+
this.computeMeanDistance(queryDescriptor, descriptors)
59+
))
60+
.reduce((best, curr) => best.distance < curr.distance ? best : curr)
61+
}
62+
63+
public findBestMatch(queryDescriptor: Float32Array): FaceMatch {
64+
const bestMatch = this.matchDescriptor(queryDescriptor)
65+
return bestMatch.distance < this.distanceThreshold
66+
? bestMatch
67+
: new FaceMatch('unknown', bestMatch.distance)
68+
}
69+
70+
}

src/globalApi/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ export * from './ComposableTask'
33
export * from './ComputeFaceDescriptorsTasks'
44
export * from './DetectFacesTasks'
55
export * from './DetectFaceLandmarksTasks'
6+
export * from './FaceMatcher'
67
export * from './nets'
78
export * from './types'
89

0 commit comments

Comments
 (0)