Skip to content

Commit 0e899bd

Browse files
extract image patch data in bgr format to avoid channel swapping on gpu
1 parent 08aae43 commit 0e899bd

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

src/mtcnn/extractImagePatches.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { Dimensions } from '../types';
44
import { createCanvas, getContext2dOrThrow } from '../utils';
5-
import { bgrToRgbTensor } from './bgrToRgbTensor';
65
import { BoundingBox } from './BoundingBox';
76
import { normalize } from './normalize';
87

@@ -35,8 +34,10 @@ export async function extractImagePatches(
3534
const { data } = patchCtx.getImageData(0, 0, width, height)
3635

3736
const currData = []
38-
for(let i = 0; i < data.length; i++) {
39-
if ((i + 1) % 4 === 0) continue
37+
// RGBA -> BGR
38+
for(let i = 0; i < data.length; i+=4) {
39+
currData.push(data[i + 2])
40+
currData.push(data[i + 1])
4041
currData.push(data[i])
4142
}
4243
imagePatchesDatas.push(currData)
@@ -45,10 +46,10 @@ export async function extractImagePatches(
4546

4647
return imagePatchesDatas.map(data => {
4748
const t = tf.tidy(() => {
48-
const imagePatchTensor = bgrToRgbTensor(tf.transpose(
49+
const imagePatchTensor = tf.transpose(
4950
tf.tensor4d(data, [1, width, height, 3]),
5051
[0, 2, 1, 3]
51-
).toFloat()) as tf.Tensor4D
52+
).toFloat() as tf.Tensor4D
5253

5354
return normalize(imagePatchTensor)
5455
})

src/mtcnn/stage1.ts

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,26 @@ export function stage1(
7373
) {
7474
stats.stage1 = []
7575

76-
const boxesForScale = scales.map((scale) => {
76+
const pnetOutputs = scales.map((scale) => tf.tidy(() => {
7777
const statsForScale: any = { scale }
78+
const resized = rescaleAndNormalize(imgTensor, scale)
7879

79-
const { scoresTensor, regionsTensor } = tf.tidy(() => {
80-
const resized = rescaleAndNormalize(imgTensor, scale)
81-
82-
let ts = Date.now()
83-
const { prob, regions } = PNet(resized, params)
84-
statsForScale.pnet = Date.now() - ts
80+
let ts = Date.now()
81+
const { prob, regions } = PNet(resized, params)
82+
statsForScale.pnet = Date.now() - ts
8583

86-
const scoresTensor = tf.unstack(tf.unstack(prob, 3)[1])[0] as tf.Tensor2D
87-
const regionsTensor = tf.unstack(regions)[0] as tf.Tensor3D
84+
const scoresTensor = tf.unstack(tf.unstack(prob, 3)[1])[0] as tf.Tensor2D
85+
const regionsTensor = tf.unstack(regions)[0] as tf.Tensor3D
8886

89-
return {
90-
scoresTensor,
91-
regionsTensor
92-
}
93-
})
87+
return {
88+
scoresTensor,
89+
regionsTensor,
90+
scale,
91+
statsForScale
92+
}
93+
}))
9494

95+
const boxesForScale = pnetOutputs.map(({ scoresTensor, regionsTensor, scale, statsForScale }) => {
9596
const boundingBoxes = extractBoundingBoxes(
9697
scoresTensor,
9798
regionsTensor,
@@ -121,7 +122,7 @@ export function stage1(
121122
})
122123

123124
const allBoxes = boxesForScale.reduce(
124-
(all, boxes) => all.concat(boxes)
125+
(all, boxes) => all.concat(boxes), []
125126
)
126127

127128
let finalBoxes: BoundingBox[] = []

0 commit comments

Comments
 (0)