Skip to content

Commit 08aae43

Browse files
runnin nets sequentially instead of in batches seems to be faster + gather runtime stats
1 parent e04770c commit 08aae43

File tree

7 files changed

+169
-60
lines changed

7 files changed

+169
-60
lines changed

src/mtcnn/Mtcnn.ts

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ import { TNetInput } from '../types';
1010
import { bgrToRgbTensor } from './bgrToRgbTensor';
1111
import { extractParams } from './extractParams';
1212
import { FaceLandmarks5 } from './FaceLandmarks5';
13+
import { getSizesForScale } from './getSizesForScale';
1314
import { pyramidDown } from './pyramidDown';
1415
import { stage1 } from './stage1';
1516
import { stage2 } from './stage2';
1617
import { stage3 } from './stage3';
17-
import { NetParams } from './types';
18+
import { MtcnnResult, NetParams } from './types';
1819

1920
export class Mtcnn extends NeuralNetwork<NetParams> {
2021

@@ -26,8 +27,9 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
2627
input: NetInput,
2728
minFaceSize: number = 20,
2829
scaleFactor: number = 0.709,
30+
maxNumScales: number = 10,
2931
scoreThresholds: number[] = [0.6, 0.7, 0.7]
30-
): Promise<any> {
32+
): Promise<{ results: MtcnnResult[], stats: any }> {
3133

3234
const { params } = this
3335

@@ -42,6 +44,10 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
4244
throw new Error('Mtcnn - inputCanvas is not defined, note that passing tensors into Mtcnn.forwardInput is not supported yet.')
4345
}
4446

47+
const stats: any = {}
48+
49+
const tsTotal = Date.now()
50+
4551
const imgTensor = tf.tidy(() =>
4652
bgrToRgbTensor(
4753
tf.expandDims(inputTensor).toFloat() as tf.Tensor4D
@@ -51,18 +57,47 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
5157
const [height, width] = imgTensor.shape.slice(1)
5258

5359
const scales = pyramidDown(minFaceSize, scaleFactor, [height, width])
54-
const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet)
60+
.filter(scale => {
61+
const sizes = getSizesForScale(scale, [height, width])
62+
return Math.min(sizes.width, sizes.height) > 48
63+
})
64+
.slice(0, maxNumScales)
65+
66+
stats.scales = scales
67+
stats.pyramid = scales.map(scale => getSizesForScale(scale, [height, width]))
68+
69+
let ts = Date.now()
70+
const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet, stats)
71+
stats.total_stage1 = Date.now() - ts
72+
73+
if (!out1.boxes.length) {
74+
stats.total = Date.now() - tsTotal
75+
return { results: [], stats }
76+
}
5577

78+
stats.stage2_numInputBoxes = out1.boxes.length
5679
// using the inputCanvas to extract and resize the image patches, since it is faster
5780
// than doing this on the gpu
58-
const out2 = await stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet)
59-
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet)
81+
ts = Date.now()
82+
const out2 = await stage2(inputCanvas, out1.boxes, scoreThresholds[1], params.rnet, stats)
83+
stats.total_stage2 = Date.now() - ts
84+
85+
if (!out2.boxes.length) {
86+
stats.total = Date.now() - tsTotal
87+
return { results: [], stats }
88+
}
89+
90+
stats.stage3_numInputBoxes = out2.boxes.length
91+
92+
ts = Date.now()
93+
const out3 = await stage3(inputCanvas, out2.boxes, scoreThresholds[2], params.onet, stats)
94+
stats.total_stage3 = Date.now() - ts
6095

6196
imgTensor.dispose()
6297
input.dispose()
6398

64-
const faceDetections = out3.boxes.map((box, idx) =>
65-
new FaceDetection(
99+
const results = out3.boxes.map((box, idx) => ({
100+
faceDetection: new FaceDetection(
66101
out3.scores[idx],
67102
new Rect(
68103
box.left / width,
@@ -74,32 +109,47 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
74109
height,
75110
width
76111
}
77-
)
78-
)
79-
80-
const faceLandmarks = out3.points.map(pts =>
81-
new FaceLandmarks5(
82-
pts.map(pt => pt.div(new Point(width, height))),
112+
),
113+
faceLandmarks: new FaceLandmarks5(
114+
out3.points[idx].map(pt => pt.div(new Point(width, height))),
83115
{ width, height }
84116
)
85-
)
117+
}))
86118

87-
return {
88-
faceDetections,
89-
faceLandmarks
90-
}
119+
stats.total = Date.now() - tsTotal
120+
return { results, stats }
91121
}
92122

93123
public async forward(
94124
input: TNetInput,
95125
minFaceSize: number = 20,
96126
scaleFactor: number = 0.709,
127+
maxNumScales: number = 10,
128+
scoreThresholds: number[] = [0.6, 0.7, 0.7]
129+
): Promise<MtcnnResult[]> {
130+
return (
131+
await this.forwardInput(
132+
await toNetInput(input, true, true),
133+
minFaceSize,
134+
scaleFactor,
135+
maxNumScales,
136+
scoreThresholds
137+
)
138+
).results
139+
}
140+
141+
public async forwardWithStats(
142+
input: TNetInput,
143+
minFaceSize: number = 20,
144+
scaleFactor: number = 0.709,
145+
maxNumScales: number = 10,
97146
scoreThresholds: number[] = [0.6, 0.7, 0.7]
98-
): Promise<tf.Tensor2D> {
147+
): Promise<{ results: MtcnnResult[], stats: any }> {
99148
return this.forwardInput(
100149
await toNetInput(input, true, true),
101150
minFaceSize,
102151
scaleFactor,
152+
maxNumScales,
103153
scoreThresholds
104154
)
105155
}

src/mtcnn/extractImagePatches.ts

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export async function extractImagePatches(
1010
img: HTMLCanvasElement,
1111
boxes: BoundingBox[],
1212
{ width, height }: Dimensions
13-
): Promise<tf.Tensor4D> {
13+
): Promise<tf.Tensor4D[]> {
1414

1515

1616
const imgCtx = getContext2dOrThrow(img)
@@ -26,26 +26,32 @@ export async function extractImagePatches(
2626
return createImageBitmap(imgData)
2727
}))
2828

29-
const imagePatchesData: number[] = []
29+
const imagePatchesDatas: number[][] = []
3030

3131
bitmaps.forEach(bmp => {
3232
const patch = createCanvas({ width, height })
3333
const patchCtx = getContext2dOrThrow(patch)
3434
patchCtx.drawImage(bmp, 0, 0, width, height)
3535
const { data } = patchCtx.getImageData(0, 0, width, height)
3636

37+
const currData = []
3738
for(let i = 0; i < data.length; i++) {
3839
if ((i + 1) % 4 === 0) continue
39-
imagePatchesData.push(data[i])
40+
currData.push(data[i])
4041
}
42+
imagePatchesDatas.push(currData)
4143
})
4244

43-
return tf.tidy(() => {
44-
const imagePatchTensor = bgrToRgbTensor(tf.transpose(
45-
tf.tensor4d(imagePatchesData, [boxes.length, width, height, 3]),
46-
[0, 2, 1, 3]
47-
).toFloat()) as tf.Tensor4D
4845

49-
return normalize(imagePatchTensor)
46+
return imagePatchesDatas.map(data => {
47+
const t = tf.tidy(() => {
48+
const imagePatchTensor = bgrToRgbTensor(tf.transpose(
49+
tf.tensor4d(data, [1, width, height, 3]),
50+
[0, 2, 1, 3]
51+
).toFloat()) as tf.Tensor4D
52+
53+
return normalize(imagePatchTensor)
54+
})
55+
return t
5056
})
5157
}

src/mtcnn/getSizesForScale.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
export function getSizesForScale(scale: number, [height, width]: number[]) {
2+
return {
3+
height: Math.floor(height * scale),
4+
width: Math.floor(width * scale)
5+
}
6+
}

src/mtcnn/stage1.ts

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import { nms } from './nms';
77
import { normalize } from './normalize';
88
import { PNet } from './PNet';
99
import { PNetParams } from './types';
10+
import { getSizesForScale } from './getSizesForScale';
1011

1112
function rescaleAndNormalize(x: tf.Tensor4D, scale: number): tf.Tensor4D {
1213
return tf.tidy(() => {
1314

14-
const [height, width] = x.shape.slice(1)
15-
const resized = tf.image.resizeBilinear(x, [Math.floor(height * scale), Math.floor(width * scale)])
15+
const { height, width } = getSizesForScale(scale, x.shape.slice(1))
16+
const resized = tf.image.resizeBilinear(x, [height, width])
1617
const normalized = normalize(resized)
1718

1819
return (tf.transpose(normalized, [0, 2, 1, 3]) as tf.Tensor4D)
@@ -67,17 +68,20 @@ export function stage1(
6768
imgTensor: tf.Tensor4D,
6869
scales: number[],
6970
scoreThreshold: number,
70-
params: PNetParams
71+
params: PNetParams,
72+
stats: any
7173
) {
74+
stats.stage1 = []
7275

73-
const boxesForScale = scales.map((scale, i) => {
76+
const boxesForScale = scales.map((scale) => {
77+
const statsForScale: any = { scale }
7478

7579
const { scoresTensor, regionsTensor } = tf.tidy(() => {
7680
const resized = rescaleAndNormalize(imgTensor, scale)
7781

78-
82+
let ts = Date.now()
7983
const { prob, regions } = PNet(resized, params)
80-
84+
statsForScale.pnet = Date.now() - ts
8185

8286
const scoresTensor = tf.unstack(tf.unstack(prob, 3)[1])[0] as tf.Tensor2D
8387
const regionsTensor = tf.unstack(regions)[0] as tf.Tensor3D
@@ -99,15 +103,20 @@ export function stage1(
99103
regionsTensor.dispose()
100104

101105
if (!boundingBoxes.length) {
106+
stats.stage1.push(statsForScale)
102107
return []
103108
}
104109

110+
let ts = Date.now()
105111
const indices = nms(
106112
boundingBoxes.map(bbox => bbox.cell),
107113
boundingBoxes.map(bbox => bbox.score),
108114
0.5
109115
)
116+
statsForScale.nms = Date.now() - ts
117+
statsForScale.numBoxes = indices.length
110118

119+
stats.stage1.push(statsForScale)
111120
return indices.map(boxIdx => boundingBoxes[boxIdx])
112121
})
113122

@@ -119,11 +128,13 @@ export function stage1(
119128
let finalScores: number[] = []
120129

121130
if (allBoxes.length > 0) {
131+
let ts = Date.now()
122132
const indices = nms(
123133
allBoxes.map(bbox => bbox.cell),
124134
allBoxes.map(bbox => bbox.score),
125135
0.7
126136
)
137+
stats.stage1_nms = Date.now() - ts
127138

128139
finalScores = indices.map(idx => allBoxes[idx].score)
129140
finalBoxes = indices

src/mtcnn/stage2.ts

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,26 @@ export async function stage2(
88
img: HTMLCanvasElement,
99
inputBoxes: BoundingBox[],
1010
scoreThreshold: number,
11-
params: RNetParams
11+
params: RNetParams,
12+
stats: any
1213
) {
1314

14-
const rnetInput = await extractImagePatches(img, inputBoxes, { width: 24, height: 24 })
15-
const rnetOut = RNet(rnetInput, params)
15+
let ts = Date.now()
16+
const rnetInputs = await extractImagePatches(img, inputBoxes, { width: 24, height: 24 })
17+
stats.stage2_extractImagePatches = Date.now() - ts
1618

17-
rnetInput.dispose()
19+
ts = Date.now()
20+
const rnetOuts = rnetInputs.map(
21+
rnetInput => {
22+
const out = RNet(rnetInput, params)
23+
rnetInput.dispose()
24+
return out
25+
}
26+
)
27+
stats.stage2_rnet = Date.now() - ts
1828

19-
const scores = Array.from(await rnetOut.scores.data())
29+
const scoreDatas = await Promise.all(rnetOuts.map(out => out.scores.data()))
30+
const scores = scoreDatas.map(arr => Array.from(arr)).reduce((all, arr) => all.concat(arr))
2031
const indices = scores
2132
.map((score, idx) => ({ score, idx }))
2233
.filter(c => c.score > scoreThreshold)
@@ -29,27 +40,31 @@ export async function stage2(
2940
let finalScores: number[] = []
3041

3142
if (filteredBoxes.length > 0) {
43+
ts = Date.now()
3244
const indicesNms = nms(
3345
filteredBoxes,
3446
filteredScores,
3547
0.7
3648
)
49+
stats.stage2_nms = Date.now() - ts
3750

3851
const regions = indicesNms.map(idx =>
3952
new BoundingBox(
40-
rnetOut.regions.get(indices[idx], 0),
41-
rnetOut.regions.get(indices[idx], 1),
42-
rnetOut.regions.get(indices[idx], 2),
43-
rnetOut.regions.get(indices[idx], 3)
53+
rnetOuts[indices[idx]].regions.get(0, 0),
54+
rnetOuts[indices[idx]].regions.get(0, 1),
55+
rnetOuts[indices[idx]].regions.get(0, 2),
56+
rnetOuts[indices[idx]].regions.get(0, 3)
4457
)
4558
)
4659

4760
finalScores = indicesNms.map(idx => filteredScores[idx])
4861
finalBoxes = indicesNms.map((idx, i) => filteredBoxes[idx].calibrate(regions[i]))
4962
}
5063

51-
rnetOut.regions.dispose()
52-
rnetOut.scores.dispose()
64+
rnetOuts.forEach(t => {
65+
t.regions.dispose()
66+
t.scores.dispose()
67+
})
5368

5469
return {
5570
boxes: finalBoxes,

0 commit comments

Comments
 (0)