Skip to content

Commit e7d1d04

Browse files
concat score tensors before getting their data, apparently it's faster this way
1 parent 0e899bd commit e7d1d04

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

src/mtcnn/stage2.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { tf } from '..';
12
import { BoundingBox } from './BoundingBox';
23
import { extractImagePatches } from './extractImagePatches';
34
import { nms } from './nms';
@@ -26,8 +27,12 @@ export async function stage2(
2627
)
2728
stats.stage2_rnet = Date.now() - ts
2829

29-
const scoreDatas = await Promise.all(rnetOuts.map(out => out.scores.data()))
30-
const scores = scoreDatas.map(arr => Array.from(arr)).reduce((all, arr) => all.concat(arr))
30+
const scoresTensor = rnetOuts.length > 1
31+
? tf.concat(rnetOuts.map(out => out.scores))
32+
: rnetOuts[0].scores
33+
const scores = Array.from(await scoresTensor.data())
34+
scoresTensor.dispose()
35+
3136
const indices = scores
3237
.map((score, idx) => ({ score, idx }))
3338
.filter(c => c.score > scoreThreshold)

src/mtcnn/stage3.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { extractImagePatches } from './extractImagePatches';
44
import { nms } from './nms';
55
import { ONet } from './ONet';
66
import { ONetParams } from './types';
7+
import { tf } from '..';
78

89
export async function stage3(
910
img: HTMLCanvasElement,
@@ -27,8 +28,12 @@ export async function stage3(
2728
)
2829
stats.stage3_onet = Date.now() - ts
2930

30-
const scoreDatas = await Promise.all(onetOuts.map(out => out.scores.data()))
31-
const scores = scoreDatas.map(arr => Array.from(arr)).reduce((all, arr) => all.concat(arr))
31+
const scoresTensor = onetOuts.length > 1
32+
? tf.concat(onetOuts.map(out => out.scores))
33+
: onetOuts[0].scores
34+
const scores = Array.from(await scoresTensor.data())
35+
scoresTensor.dispose()
36+
3237
const indices = scores
3338
.map((score, idx) => ({ score, idx }))
3439
.filter(c => c.score > scoreThreshold)

0 commit comments

Comments
 (0)