
Commit 5e957ec

implemented stage2
1 parent 084cd30 commit 5e957ec

11 files changed, +223 -28 lines changed

src/NetInput.ts

Lines changed: 13 additions & 5 deletions
@@ -8,6 +8,7 @@ import { createCanvasFromMedia } from './utils';
 
 export class NetInput {
   private _inputs: tf.Tensor3D[] = []
+  private _canvases: HTMLCanvasElement[] = []
   private _isManaged: boolean = false
   private _isBatchInput: boolean = false
 
@@ -16,14 +17,15 @@ export class NetInput {
 
   constructor(
     inputs: tf.Tensor4D | Array<TResolvedNetInput>,
-    isBatchInput: boolean = false
+    isBatchInput: boolean = false,
+    keepCanvases: boolean = false
   ) {
     if (isTensor4D(inputs)) {
       this._inputs = tf.unstack(inputs as tf.Tensor4D) as tf.Tensor3D[]
     }
 
     if (Array.isArray(inputs)) {
-      this._inputs = inputs.map(input => {
+      this._inputs = inputs.map((input, idx) => {
         if (isTensor3D(input)) {
           // TODO: make sure not to dispose original tensors passed in by the user
           return tf.clone(input as tf.Tensor3D)
@@ -39,9 +41,11 @@ export class NetInput {
           return (input as tf.Tensor4D).reshape(shape.slice(1) as [number, number, number]) as tf.Tensor3D
         }
 
-        return tf.fromPixels(
-          input instanceof HTMLCanvasElement ? input : createCanvasFromMedia(input as HTMLImageElement | HTMLVideoElement)
-        )
+        const canvas = input instanceof HTMLCanvasElement ? input : createCanvasFromMedia(input as HTMLImageElement | HTMLVideoElement)
+        if (keepCanvases) {
+          this._canvases[idx] = canvas
+        }
+        return tf.fromPixels(canvas)
       })
     }
 
@@ -53,6 +57,10 @@ export class NetInput {
     return this._inputs
   }
 
+  public get canvases(): HTMLCanvasElement[] {
+    return this._canvases
+  }
+
   public get isManaged(): boolean {
     return this._isManaged
   }
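Note: the new keepCanvases flag makes NetInput keep the canvas each pixel tensor was created from, so later stages can read pixels back through a 2d context rather than the GPU. A minimal usage sketch (the video element is made up for illustration):

    import { toNetInput } from './toNetInput';

    async function example(video: HTMLVideoElement) {
      // the third argument (keepCanvases) is new in this commit
      const netInput = await toNetInput(video, true, true)
      const tensor = netInput.inputs[0]    // tf.Tensor3D created via tf.fromPixels
      const canvas = netInput.canvases[0]  // the HTMLCanvasElement backing that tensor
      return { tensor, canvas }
    }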

src/mtcnn/BoundingBox.ts

Lines changed: 34 additions & 0 deletions
@@ -55,4 +55,38 @@ export class BoundingBox {
       Math.round(this.bottom)
     )
   }
+
+  public padAtBorders(imageHeight: number, imageWidth: number) {
+    const w = this.width + 1
+    const h = this.height + 1
+
+    let dx = 1
+    let dy = 1
+    let edx = w
+    let edy = h
+
+    let x = this.left
+    let y = this.top
+    let ex = this.right
+    let ey = this.bottom
+
+    if (ex > imageWidth) {
+      edx = -ex + imageWidth + w
+      ex = imageWidth
+    }
+    if (ey > imageHeight) {
+      edy = -ey + imageHeight + h
+      ey = imageHeight
+    }
+    if (x < 1) {
+      edy = 2 - x
+      x = 1
+    }
+    if (y < 1) {
+      edy = 2 - y
+      y = 1
+    }
+
+    return { dy, edy, dx, edx, y, ey, x, ex, w, h }
+  }
 }
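Note: padAtBorders mirrors the padding step of the reference MTCNN implementation. It clips a proposal box to the image area (1-based coordinates) and also returns offsets (dx, edx, dy, edy) for copying the visible part of the box into a patch; stage2 below currently only uses the clipped coordinates. A rough sketch of the clipping (box and image sizes made up):

    import { BoundingBox } from './BoundingBox';

    // a 40x40 image, with a box extending past the left, right and bottom edges
    const box = new BoundingBox(-10, 5, 50, 60)
    const { x, y, ex, ey } = box.padAtBorders(40, 40)
    // x and y are clamped to >= 1, ex and ey to the image dimensions,
    // so the getImageData call in stage2 stays inside the canvas:
    // x === 1, y === 5, ex === 40, ey === 40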

src/mtcnn/Mtcnn.ts

Lines changed: 26 additions & 11 deletions
@@ -8,6 +8,7 @@ import { bgrToRgbTensor } from './bgrToRgbTensor';
 import { extractParams } from './extractParams';
 import { pyramidDown } from './pyramidDown';
 import { stage1 } from './stage1';
+import { stage2 } from './stage2';
 import { NetParams } from './types';
 
 export class Mtcnn extends NeuralNetwork<NetParams> {
@@ -16,31 +17,45 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
     super('Mtcnn')
   }
 
-  public forwardInput(
+  public async forwardInput(
     input: NetInput,
     minFaceSize: number = 20,
     scaleFactor: number = 0.709,
     scoreThresholds: number[] = [0.6, 0.7, 0.7]
-  ): tf.Tensor2D {
+  ): Promise<tf.Tensor2D> {
 
     const { params } = this
 
     if (!params) {
       throw new Error('Mtcnn - load model before inference')
     }
 
-    return tf.tidy(() => {
-      // TODO: expects bgr input?
-      let imgTensor = bgrToRgbTensor(
-        tf.expandDims(input.inputs[0]).toFloat() as tf.Tensor4D
+    const inputTensor = input.inputs[0]
+    const inputCanvas = input.canvases[0]
+
+    if (!inputCanvas) {
+      throw new Error('Mtcnn - inputCanvas is not defined, note that passing tensors into Mtcnn.forwardInput is not supported yet.')
+    }
+
+    const imgTensor = tf.tidy(() =>
+      bgrToRgbTensor(
+        tf.expandDims(inputTensor).toFloat() as tf.Tensor4D
      )
+    )
+
+    const scales = pyramidDown(minFaceSize, scaleFactor, imgTensor.shape.slice(1))
+    const out1 = await stage1(imgTensor, scales, scoreThresholds[0], params.pnet)
+
+    // using the inputCanvas to extract and resize the image patches, since it is faster
+    // than doing this on the gpu
+    const out2 = await stage2(inputCanvas, out1, scoreThresholds[1], params.rnet)
+
 
-      const scales = pyramidDown(minFaceSize, scaleFactor, imgTensor.shape.slice(1))
 
-      const out1 = stage1(imgTensor, scales, scoreThresholds[0], params.pnet)
+    imgTensor.dispose()
+    input.dispose()
 
-      return tf.tensor2d([0], [1, 1])
-    })
+    return tf.tensor2d([0], [1, 1])
   }
 
   public async forward(
@@ -50,7 +65,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
     scoreThresholds: number[] = [0.6, 0.7, 0.7]
   ): Promise<tf.Tensor2D> {
     return this.forwardInput(
-      await toNetInput(input, true),
+      await toNetInput(input, true, true),
      minFaceSize,
      scaleFactor,
      scoreThresholds
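Note: forwardInput now chains the first two MTCNN stages: the image pyramid plus PNet proposals from stage1, then RNet refinement on canvas-extracted patches in stage2. The returned tensor is still a placeholder until the remaining stage lands. A rough sketch of the intended call flow (weight loading is assumed and not part of this diff):

    import { Mtcnn } from './mtcnn/Mtcnn';

    async function detect(imageElement: HTMLImageElement) {
      const mtcnn = new Mtcnn()
      // ... load the PNet/RNet weights into mtcnn before calling forward ...

      // forward() converts the input to a NetInput with keepCanvases = true,
      // so stage2 can read the image patches back from the canvas
      return mtcnn.forward(imageElement, 20, 0.709, [0.6, 0.7, 0.7])
    }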

src/mtcnn/RNet.ts

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+import * as tf from '@tensorflow/tfjs-core';
+
+import { fullyConnectedLayer } from '../faceLandmarkNet/fullyConnectedLayer';
+import { prelu } from './prelu';
+import { sharedLayer } from './sharedLayers';
+import { RNetParams } from './types';
+
+export function RNet(x: tf.Tensor4D, params: RNetParams): { prob: tf.Tensor2D, regions: tf.Tensor2D } {
+  return tf.tidy(() => {
+
+    const convOut = sharedLayer(x, params)
+    const vectorized = tf.reshape(convOut, [convOut.shape[0], params.fc1.weights.shape[0]]) as tf.Tensor2D
+    const fc1 = fullyConnectedLayer(vectorized, params.fc1)
+    const prelu4 = prelu<tf.Tensor2D>(fc1, params.prelu4_alpha)
+    const fc2_1 = fullyConnectedLayer(prelu4, params.fc2_1)
+    const max = tf.expandDims(tf.max(fc2_1, 1), 1)
+    const prob = tf.softmax(tf.sub(fc2_1, max), 1) as tf.Tensor2D
+    const regions = fullyConnectedLayer(prelu4, params.fc2_2)
+
+    return { prob, regions }
+  })
+}
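Note: subtracting the row-wise max before the softmax is the usual numerical-stability trick; softmax is shift-invariant, so the probabilities are unchanged. A small standalone illustration:

    import * as tf from '@tensorflow/tfjs-core';

    const logits = tf.tensor2d([[4, 2]])
    const max = tf.expandDims(tf.max(logits, 1), 1)

    // both print the same probabilities, roughly [[0.8808, 0.1192]]
    tf.softmax(logits, 1).print()
    tf.softmax(tf.sub(logits, max), 1).print()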

src/mtcnn/bgrToRgbTensor.ts

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@ import * as tf from '@tensorflow/tfjs-core';
 
 export function bgrToRgbTensor(tensor: tf.Tensor4D): tf.Tensor4D {
   return tf.tidy(
-    () => tf.stack(tf.unstack(tensor, 3), 3)
+    () => tf.stack(tf.unstack(tensor, 3).reverse(), 3)
   ) as tf.Tensor4D
 }
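Note: tf.unstack(tensor, 3) yields the channel slices in stored order, so without .reverse() the stack call reassembled them unchanged and the BGR-to-RGB swap was a no-op. A tiny sanity check:

    import * as tf from '@tensorflow/tfjs-core';

    const bgr = tf.tensor4d([10, 20, 30], [1, 1, 1, 3])  // one pixel: B=10, G=20, R=30
    const rgb = tf.stack(tf.unstack(bgr, 3).reverse(), 3)
    rgb.print()                                          // [[[[30, 20, 10]]]]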

src/mtcnn/normalize.ts

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+import * as tf from '@tensorflow/tfjs-core';
+
+export function normalize(x: tf.Tensor4D): tf.Tensor4D {
+  return tf.tidy(
+    () => tf.mul(tf.sub(x, tf.scalar(127.5)), tf.scalar(0.0078125))
+  )
+}
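Note: 0.0078125 is 1/128, so normalize applies the standard MTCNN preprocessing, mapping pixel values from [0, 255] to roughly [-1, 1]. For example:

    import * as tf from '@tensorflow/tfjs-core';
    import { normalize } from './normalize';

    normalize(tf.tensor4d([0, 127.5, 255], [1, 1, 1, 3])).print()
    // -> approximately [[[[-0.996, 0, 0.996]]]]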

src/mtcnn/prelu.ts

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import * as tf from '@tensorflow/tfjs-core';
 
-export function prelu(x: tf.Tensor4D, alpha: tf.Tensor1D): tf.Tensor4D {
+export function prelu<T extends tf.Tensor>(x: T, alpha: tf.Tensor1D): T {
   return tf.tidy(() =>
     tf.add(
       tf.relu(x),
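Note: making prelu generic lets the same helper run on the 2D fully connected outputs in RNet as well as the 4D conv outputs in the shared layers. Assuming the standard PReLU definition f(x) = max(0, x) + alpha * min(0, x) (the rest of the function body is not shown in this hunk), a 2D call would look like:

    import * as tf from '@tensorflow/tfjs-core';
    import { prelu } from './prelu';

    const fc: tf.Tensor2D = tf.tensor2d([[-1, 2]])
    const out = prelu<tf.Tensor2D>(fc, tf.tensor1d([0.25, 0.25]))
    out.print()  // expected [[-0.25, 2]] with alpha = 0.25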

src/mtcnn/sharedLayers.ts

Lines changed: 3 additions & 3 deletions
@@ -8,13 +8,13 @@ export function sharedLayer(x: tf.Tensor4D, params: SharedParams, isPnet: boolea
   return tf.tidy(() => {
 
     let out = convLayer(x, params.conv1, 'valid')
-    out = prelu(out, params.prelu1_alpha)
+    out = prelu<tf.Tensor4D>(out, params.prelu1_alpha)
     out = tf.maxPool(out, isPnet ? [2, 2]: [3, 3], [2, 2], 'same')
     out = convLayer(out, params.conv2, 'valid')
-    out = prelu(out, params.prelu2_alpha)
+    out = prelu<tf.Tensor4D>(out, params.prelu2_alpha)
     out = isPnet ? out : tf.maxPool(out, [3, 3], [2, 2], 'valid')
     out = convLayer(out, params.conv3, 'valid')
-    out = prelu(out, params.prelu3_alpha)
+    out = prelu<tf.Tensor4D>(out, params.prelu3_alpha)
 
     return out
   })

src/mtcnn/stage1.ts

Lines changed: 4 additions & 5 deletions
@@ -79,13 +79,12 @@ export function stage1(
       const { prob, regions } = PNet(resized, params)
 
 
-      const scores = tf.unstack(prob, 3)[1]
-      const [sh, sw] = scores.shape.slice(1)
-      const [rh, rw] = regions.shape.slice(1)
+      const scoresTensor = tf.unstack(tf.unstack(prob, 3)[1])[0] as tf.Tensor2D
+      const regionsTensor = tf.unstack(regions)[0] as tf.Tensor3D
 
       return {
-        scoresTensor: scores.as2D(sh, sw),
-        regionsTensor: regions.as3D(rh, rw, 4)
+        scoresTensor,
+        regionsTensor
       }
     })
 
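Note: the change here just drops the batch dimension explicitly. PNet runs on a single scaled image, so prob has shape [1, h, w, 2]; taking channel 1 and unstacking along the batch axis yields the plain [h, w] score map. Shape walk-through, assuming batch size 1:

    import * as tf from '@tensorflow/tfjs-core';

    const prob = tf.zeros([1, 3, 4, 2])              // [batch, height, width, 2]
    const faceScores = tf.unstack(prob, 3)[1]        // shape [1, 3, 4]
    const scoresTensor = tf.unstack(faceScores)[0]   // shape [3, 4]
    console.log(scoresTensor.shape)                  // [3, 4]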

src/mtcnn/stage2.ts

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import * as tf from '@tensorflow/tfjs-core';
+
+import { createCanvas, getContext2dOrThrow } from '../utils';
+import { bgrToRgbTensor } from './bgrToRgbTensor';
+import { BoundingBox } from './BoundingBox';
+import { nms } from './nms';
+import { normalize } from './normalize';
+import { RNet } from './RNet';
+import { RNetParams } from './types';
+
+export async function stage2(
+  img: HTMLCanvasElement,
+  boxes: { box: BoundingBox, score: number }[],
+  scoreThreshold: number,
+  params: RNetParams
+) {
+
+  const { height, width } = img
+
+  const imgCtx = getContext2dOrThrow(img)
+
+  const bitmaps = await Promise.all(boxes.map(async ({ box }) => {
+    // TODO: correct padding
+    const { y, ey, x, ex } = box.padAtBorders(height, width)
+
+    const fromX = x - 1
+    const fromY = y - 1
+    const imgData = imgCtx.getImageData(fromX, fromY, (ex - fromX), (ey - fromY))
+
+    return createImageBitmap(imgData)
+  }))
+
+  const imagePatchesData: number[] = []
+
+  bitmaps.forEach(bmp => {
+    const patch = createCanvas({ width: 24, height: 24 })
+    const patchCtx = getContext2dOrThrow(patch)
+    patchCtx.drawImage(bmp, 0, 0, 24, 24)
+    const { data } = patchCtx.getImageData(0, 0, 24, 24)
+
+    for(let i = 0; i < data.length; i++) {
+      if ((i + 1) % 4 === 0) continue
+      imagePatchesData.push(data[i])
+    }
+  })
+
+  const rnetOut = tf.tidy(() => {
+    const imagePatchTensor = bgrToRgbTensor(tf.transpose(
+      tf.tensor4d(imagePatchesData, [boxes.length, 24, 24, 3]),
+      [0, 2, 1, 3]
+    ).toFloat()) as tf.Tensor4D
+
+    const normalized = normalize(imagePatchTensor)
+
+    const { prob, regions } = RNet(normalized, params)
+    return {
+      scores: tf.unstack(prob, 1)[1],
+      regions
+    }
+  })
+
+  const scores = Array.from(await rnetOut.scores.data())
+
+  const indices = scores
+    .map((score, idx) => ({ score, idx }))
+    .filter(c => c.score > scoreThreshold)
+    .map(({ idx }) => idx)
+
+  const filteredBoxes = indices.map(idx => boxes[idx].box)
+  const filteredScores = indices.map(idx => scores[idx])
+
+  let finalBoxes: BoundingBox[] = []
+  let finalScores: number[] = []
+
+  if (filteredBoxes.length > 0) {
+    const indicesNms = nms(
+      filteredBoxes,
+      filteredScores,
+      0.7
+    )
+
+    finalScores = indicesNms.map(idx => filteredScores[idx])
+    finalBoxes = indicesNms
+      .map(idx => {
+        const box = filteredBoxes[idx]
+        const [rleft, rtop, right, rbottom] = [
+          rnetOut.regions.get(indices[idx], 0),
+          rnetOut.regions.get(indices[idx], 1),
+          rnetOut.regions.get(indices[idx], 2),
+          rnetOut.regions.get(indices[idx], 3)
+        ]
+
+        return new BoundingBox(
+          box.left + (rleft * box.width),
+          box.top + (rtop * box.height),
+          box.right + (right * box.width),
+          box.bottom + (rbottom * box.height)
+        ).toSquare().round()
+      })
+  }
+
+  rnetOut.regions.dispose()
+  rnetOut.scores.dispose()
+
+  return {
+    finalBoxes,
+    finalScores
+  }
+}
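Note: stage2 reads each padded proposal region out of the input canvas, resizes it to a 24x24 patch, strips the alpha channel from the RGBA bytes returned by getImageData, and feeds the stacked patches through RNet; boxes above the score threshold are de-duplicated with NMS and adjusted by the RNet region regressions. The alpha-stripping loop in isolation (standalone sketch with made-up byte values):

    // getImageData returns RGBA bytes; skipping every 4th byte keeps RGB only,
    // so each 24x24 patch contributes 24 * 24 * 3 = 1728 values
    const rgba = [10, 20, 30, 255, 40, 50, 60, 255]  // two pixels
    const rgb: number[] = []
    for (let i = 0; i < rgba.length; i++) {
      if ((i + 1) % 4 === 0) continue  // drop the alpha channel
      rgb.push(rgba[i])
    }
    console.log(rgb)  // [10, 20, 30, 40, 50, 60]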

src/toNetInput.ts

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,8 @@ import { awaitMediaLoaded, resolveInput } from './utils';
  */
 export async function toNetInput(
   inputs: TNetInput,
-  manageCreatedInput: boolean = false
+  manageCreatedInput: boolean = false,
+  keepCanvases: boolean = false
 ): Promise<NetInput> {
   if (inputs instanceof NetInput) {
     return inputs
@@ -67,5 +68,5 @@ export async function toNetInput(
     inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input))
   )
 
-  return afterCreate(new NetInput(inputArray, Array.isArray(inputs)))
+  return afterCreate(new NetInput(inputArray, Array.isArray(inputs), keepCanvases))
 }
