Skip to content

Commit 1e2d261

Browse files
face recognition net now accepts batch inputs
1 parent 8ef9b66 commit 1e2d261

File tree

9 files changed

+171
-24
lines changed

9 files changed

+171
-24
lines changed

src/allFacesFactory.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export function allFacesFactory(
3535

3636
const descriptors = await Promise.all(alignedFaceTensors.map(
3737
faceTensor => recognitionNet.computeFaceDescriptor(faceTensor)
38-
))
38+
)) as Float32Array[]
3939
alignedFaceTensors.forEach(t => t.dispose())
4040

4141
return detections.map((detection, i) =>

src/faceRecognitionNet/FaceRecognitionNet.ts

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ export class FaceRecognitionNet {
3030
this._params = extractParams(weights)
3131
}
3232

33-
public async forwardInput(input: NetInput): Promise<tf.Tensor2D> {
33+
public forwardInput(input: NetInput): tf.Tensor2D {
3434
if (!this._params) {
3535
throw new Error('FaceRecognitionNet - load model before inference')
3636
}
3737

38-
3938
return tf.tidy(() => {
4039
const batchTensor = input.toBatchTensor(150, true)
4140

@@ -68,14 +67,26 @@ export class FaceRecognitionNet {
6867
return fullyConnected
6968
})
7069
}
70+
7171
public async forward(input: TNetInput): Promise<tf.Tensor2D> {
7272
return this.forwardInput(await toNetInput(input, true))
7373
}
7474

75-
public async computeFaceDescriptor(input: TNetInput) {
76-
const result = await this.forward(await toNetInput(input, true))
77-
const data = await result.data()
78-
result.dispose()
79-
return data as Float32Array
75+
public async computeFaceDescriptor(input: TNetInput): Promise<Float32Array|Float32Array[]> {
76+
const netInput = await toNetInput(input, true)
77+
78+
const faceDescriptorTensors = tf.tidy(
79+
() => tf.unstack(this.forwardInput(netInput))
80+
)
81+
82+
const faceDescriptorsForBatch = await Promise.all(faceDescriptorTensors.map(
83+
t => t.data()
84+
)) as Float32Array[]
85+
86+
faceDescriptorTensors.forEach(t => t.dispose())
87+
88+
return netInput.isBatchInput
89+
? faceDescriptorsForBatch
90+
: faceDescriptorsForBatch[0]
8091
}
8192
}

src/faceRecognitionNet/normalize.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
export function normalize(x: tf.Tensor4D): tf.Tensor4D {
44
return tf.tidy(() => {
5-
const avg_r = tf.fill([1, 150, 150, 1], 122.782);
6-
const avg_g = tf.fill([1, 150, 150, 1], 117.001);
7-
const avg_b = tf.fill([1, 150, 150, 1], 104.298);
5+
const avg_r = tf.fill([...x.shape.slice(0, 3), 1], 122.782);
6+
const avg_g = tf.fill([...x.shape.slice(0, 3), 1], 117.001);
7+
const avg_b = tf.fill([...x.shape.slice(0, 3), 1], 104.298);
88
const avg_rgb = tf.concat([avg_r, avg_g, avg_b], 3)
99

1010
return tf.div(tf.sub(x, avg_rgb), tf.scalar(256))

src/globalApi.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export function detectLandmarks(
5050

5151
export function computeFaceDescriptor(
5252
input: TNetInput
53-
): Promise<Float32Array> {
53+
): Promise<Float32Array | Float32Array[]> {
5454
return recognitionNet.computeFaceDescriptor(input)
5555
}
5656

File renamed without changes.

test/data/faceDescriptor2.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[-0.08900658041238785, 0.10903996974229813, 0.027176279574632645, 0.04400758072733879, -0.14542895555496216, 0.11051996797323227, -0.04482650384306908, -0.05154910683631897, 0.10313281416893005, -0.09580713510513306, 0.11335672438144684, -0.02723177894949913, -0.2017219066619873, 0.09402787685394287, -0.025814395397901535, 0.07219463586807251, -0.12272300571203232, -0.07349629700183868, -0.1723618507385254, -0.1745331585407257, -0.03420797362923622, 0.10511981695890427, 0.0262751504778862, 0.014430010691285133, -0.2035353034734726, -0.2949812114238739, -0.04833773523569107, -0.10960741341114044, 0.08448510617017746, -0.039910122752189636, -0.03964325413107872, -0.099286288022995, -0.16025686264038086, 0.026379037648439407, 0.09079921245574951, 0.07745557278394699, -0.05415252223610878, -0.017411116510629654, 0.16053830087184906, 0.010681805200874805, -0.11814302206039429, 0.0382964164018631, 0.08098040521144867, 0.29891595244407654, 0.1258186250925064, 0.06479117274284363, 0.02330329827964306, -0.07838230580091476, 0.1363348364830017, -0.21215586364269257, 0.07675530016422272, 0.1447518914937973, 0.14686468243598938, 0.06991209089756012, 0.08843740075826645, -0.11935211718082428, -0.015284902416169643, 0.16930945217609406, -0.044002968817949295, 0.16501764953136444, 0.10481955111026764, -0.013367846608161926, -0.05079612880945206, -0.07971523702144623, 0.2541899085044861, 0.07128541171550751, -0.1458708792924881, -0.15604135394096375, 0.11365226656198502, -0.16018034517765045, -0.034580036997795105, 0.05678928270936012, -0.07191935181617737, -0.15881866216659546, -0.1955043375492096, 0.06456604599952698, 0.5308966040611267, 0.13605228066444397, -0.18340089917182922, -0.054736778140068054, -0.09668046236038208, -0.0006025233305990696, 0.06609033048152924, 0.0835171788930893, -0.13018545508384705, -0.07167276740074158, -0.04313529655337334, 0.08809386193752289, 0.29993879795074463, -0.07008976489305496, 0.005112136714160442, 0.1464609056711197, 0.03064284473657608, 0.005341261625289917, -0.03758316487073898, -0.002741048112511635, -0.19020092487335205, -0.005203879438340664, -0.03693881630897522, 0.017715569585561752, 0.025151528418064117, -0.1393381506204605, 0.04255775362253189, 0.080945685505867, -0.23745450377464294, 0.21049565076828003, -0.01615971140563488, -0.0642223060131073, 0.0915207713842392, 0.10660708695650101, -0.14731745421886444, -0.027426915243268013, 0.2378913164138794, -0.2964036166667938, 0.2034282684326172, 0.2009482979774475, 0.04706001281738281, 0.13964271545410156, 0.05233509838581085, 0.11507777869701385, 0.045886922627687454, 0.12765641510486603, -0.15917260944843292, -0.13223722577095032, -0.023241272196173668, -0.129884734749794, -0.027176398783922195, 0.009421694092452526]

test/data/faceDescriptorRect.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[-0.13293321430683136, 0.09793781489133835, 0.06550372391939163, 0.02364283800125122, -0.043399304151535034, 0.004586201161146164, -0.09000064432621002, -0.05539097636938095, 0.10467389971017838, -0.09715163707733154, 0.18808841705322266, -0.0205547958612442, -0.23795807361602783, -0.026068881154060364, -0.04790578782558441, 0.10736768692731857, -0.1791372150182724, -0.09754926711320877, -0.08212480694055557, -0.07197146117687225, 0.07512062042951584, 0.06562784314155579, -0.06910805404186249, 0.010537944734096527, -0.1353086233139038, -0.29961100220680237, -0.04597249627113342, -0.09019482880830765, 0.04843198508024216, -0.08456507325172424, -0.06385420262813568, 0.09591938555240631, -0.08721363544464111, 0.0029465071856975555, 0.062499962747097015, 0.08367685973644257, 0.004837760701775551, -0.02126195654273033, 0.18138188123703003, -0.0330311618745327, -0.1149168312549591, -0.014434240758419037, 0.04467501491308212, 0.32643717527389526, 0.13417592644691467, 0.049149081110954285, 0.0002636462450027466, -0.030674105510115623, 0.15085124969482422, -0.25617715716362, 0.007638035342097282, 0.20309507846832275, 0.155135378241539, 0.10535001009702682, 0.09949050843715668, -0.19686023890972137, 0.055761925876140594, 0.10784860700368881, -0.16404221951961517, 0.12705324590206146, 0.06780532747507095, -0.12821750342845917, -0.015174079686403275, -0.08541303128004074, 0.23064906895160675, 0.04403648152947426, -0.16575516760349274, -0.10698974132537842, 0.13322079181671143, -0.10516376793384552, -0.03650324046611786, 0.05603502690792084, -0.1468498408794403, -0.21398313343524933, -0.22947216033935547, 0.022328242659568787, 0.4006509780883789, 0.2338075339794159, -0.1980385184288025, 0.05581464245915413, -0.033158354461193085, -0.047999653965234756, 0.10474226623773575, 0.11267579346895218, -0.0938166156411171, -0.005631402134895325, -0.0698985829949379, 0.06661885231733322, 0.18326956033706665, 0.042940653860569, -0.031386956572532654, 0.2056775838136673, 0.011491281911730766, 0.05759737640619278, -0.029466431587934494, -0.04597870260477066, -0.07393362373113632, -0.037820909172296524, -0.07149908691644669, 0.023783499374985695, 0.016364723443984985, -0.09576655924320221, 0.02455282025039196, 0.11984197050333023, -0.11477060616016388, 0.17211446166038513, -0.008100427687168121, 0.09116753190755844, -0.004660069011151791, 0.029939215630292892, -0.10707360506057739, 0.03878428786993027, 0.15494686365127563, -0.2801153063774109, 0.1764734983444214, 0.1614546924829483, 0.09864784777164459, 0.12133727967739105, 0.05214153230190277, 0.04244184494018555, 0.024142231792211533, -0.019513756036758423, -0.22539466619491577, -0.0927465632557869, 0.06196486949920654, -0.09522707760334015, 0.04965142160654068, 0.023237790912389755]

test/tests/e2e/faceLandmarkNet.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ describe('faceLandmarkNet', () => {
8787
await faceLandmarkNet.load('base/weights')
8888
})
8989

90-
it('computes face landmarks', async () => {
90+
it('computes face landmarks for squared input', async () => {
9191
const { width, height } = imgEl1
9292

9393
const result = await faceLandmarkNet.detectLandmarks(imgEl1) as FaceLandmarks
Lines changed: 145 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,156 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
13
import * as faceapi from '../../../src';
24

35
describe('faceRecognitionNet', () => {
46

5-
let faceRecognitionNet: any, imgEl: HTMLImageElement, faceDescriptor: number[]
7+
let imgEl1: HTMLImageElement
8+
let imgEl2: HTMLImageElement
9+
let imgElRect: HTMLImageElement
10+
let faceDescriptor1: number[]
11+
let faceDescriptor2: number[]
12+
let faceDescriptorRect: number[]
613

714
beforeAll(async () => {
8-
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
9-
const weights = new Float32Array(await res.arrayBuffer())
10-
faceRecognitionNet = faceapi.faceRecognitionNet(weights)
15+
const img1 = await (await fetch('base/test/images/face1.png')).blob()
16+
imgEl1 = await faceapi.bufferToImage(img1)
17+
const img2 = await (await fetch('base/test/images/face2.png')).blob()
18+
imgEl2 = await faceapi.bufferToImage(img2)
19+
const imgRect = await (await fetch('base/test/images/face_rectangular.png')).blob()
20+
imgElRect = await faceapi.bufferToImage(imgRect)
21+
faceDescriptor1 = await (await fetch('base/test/data/faceDescriptor1.json')).json()
22+
faceDescriptor2 = await (await fetch('base/test/data/faceDescriptor2.json')).json()
23+
faceDescriptorRect = await (await fetch('base/test/data/faceDescriptorRect.json')).json()
24+
})
25+
26+
describe('uncompressed weights', () => {
27+
28+
let faceRecognitionNet: faceapi.FaceRecognitionNet
29+
30+
beforeAll(async () => {
31+
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
32+
const weights = new Float32Array(await res.arrayBuffer())
33+
faceRecognitionNet = faceapi.faceRecognitionNet(weights)
34+
})
35+
36+
it('computes face descriptor for squared input', async () => {
37+
const result = await faceRecognitionNet.computeFaceDescriptor(imgEl1) as Float32Array
38+
expect(result.length).toEqual(128)
39+
expect(result).toEqual(new Float32Array(faceDescriptor1))
40+
})
41+
42+
it('computes face descriptor for rectangular input', async () => {
43+
const result = await faceRecognitionNet.computeFaceDescriptor(imgElRect) as Float32Array
44+
expect(result.length).toEqual(128)
45+
expect(result).toEqual(new Float32Array(faceDescriptorRect))
46+
})
1147

12-
const img = await (await fetch('base/test/images/face1.png')).blob()
13-
imgEl = await faceapi.bufferToImage(img)
14-
faceDescriptor = await (await fetch('base/test/data/faceDescriptor.json')).json()
1548
})
1649

17-
it('computes face descriptor', async () => {
18-
const result = await faceRecognitionNet.computeFaceDescriptor(imgEl) as number[]
19-
expect(result.length).toEqual(128)
20-
expect(result).toEqual(new Float32Array(faceDescriptor))
50+
// TODO: figure out why descriptors return NaN in the test cases
51+
/*
52+
describe('quantized weights', () => {
53+
54+
let faceRecognitionNet: faceapi.FaceRecognitionNet
55+
56+
beforeAll(async () => {
57+
faceRecognitionNet = new faceapi.FaceRecognitionNet()
58+
await faceRecognitionNet.load('base/weights')
59+
})
60+
61+
it('computes face descriptor for squared input', async () => {
62+
const result = await faceRecognitionNet.computeFaceDescriptor(imgEl1) as Float32Array
63+
expect(result.length).toEqual(128)
64+
expect(result).toEqual(new Float32Array(faceDescriptor1))
65+
})
66+
67+
it('computes face descriptor for rectangular input', async () => {
68+
const result = await faceRecognitionNet.computeFaceDescriptor(imgElRect) as Float32Array
69+
expect(result.length).toEqual(128)
70+
expect(result).toEqual(new Float32Array(faceDescriptorRect))
71+
})
72+
73+
})
74+
*/
75+
76+
describe('batch inputs', () => {
77+
78+
let faceRecognitionNet: faceapi.FaceRecognitionNet
79+
80+
beforeAll(async () => {
81+
const res = await fetch('base/weights/uncompressed/face_recognition_model.weights')
82+
const weights = new Float32Array(await res.arrayBuffer())
83+
faceRecognitionNet = faceapi.faceRecognitionNet(weights)
84+
})
85+
86+
it('computes face descriptors for batch of image elements', async () => {
87+
const inputs = [imgEl1, imgEl2, imgElRect]
88+
89+
const faceDescriptors = [
90+
faceDescriptor1,
91+
faceDescriptor2,
92+
faceDescriptorRect
93+
]
94+
95+
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
96+
expect(Array.isArray(results)).toBe(true)
97+
expect(results.length).toEqual(3)
98+
results.forEach((result, batchIdx) => {
99+
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
100+
})
101+
})
102+
103+
it('computes face landmarks for batch of tf.Tensor3D', async () => {
104+
const inputs = [imgEl1, imgEl2, imgElRect].map(el => tf.fromPixels(el))
105+
106+
const faceDescriptors = [
107+
faceDescriptor1,
108+
faceDescriptor2,
109+
faceDescriptorRect
110+
]
111+
112+
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
113+
expect(Array.isArray(results)).toBe(true)
114+
expect(results.length).toEqual(3)
115+
results.forEach((result, batchIdx) => {
116+
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
117+
})
118+
})
119+
120+
it('computes face landmarks for tf.Tensor4D', async () => {
121+
const inputs = [imgEl1, imgEl2].map(el => tf.fromPixels(el))
122+
123+
const faceDescriptors = [
124+
faceDescriptor1,
125+
faceDescriptor2,
126+
faceDescriptorRect
127+
]
128+
129+
const results = await faceRecognitionNet.computeFaceDescriptor(tf.stack(inputs) as tf.Tensor4D) as Float32Array[]
130+
expect(Array.isArray(results)).toBe(true)
131+
expect(results.length).toEqual(2)
132+
results.forEach((result, batchIdx) => {
133+
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
134+
})
135+
})
136+
137+
it('computes face landmarks for batch of mixed inputs', async () => {
138+
const inputs = [imgEl1, tf.fromPixels(imgEl2), tf.fromPixels(imgElRect)]
139+
140+
const faceDescriptors = [
141+
faceDescriptor1,
142+
faceDescriptor2,
143+
faceDescriptorRect
144+
]
145+
146+
const results = await faceRecognitionNet.computeFaceDescriptor(inputs) as Float32Array[]
147+
expect(Array.isArray(results)).toBe(true)
148+
expect(results.length).toEqual(3)
149+
results.forEach((result, batchIdx) => {
150+
expect(result).toEqual(new Float32Array(faceDescriptors[batchIdx]))
151+
})
152+
})
153+
21154
})
155+
22156
})

0 commit comments

Comments
 (0)