Skip to content

Commit 9ffe1cc

Browse files
finished stage1
1 parent 33296ef commit 9ffe1cc

File tree

4 files changed

+208
-66
lines changed

4 files changed

+208
-66
lines changed

src/mtcnn/BoundingBox.ts

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/**
 * Immutable axis-aligned rectangle described by the coordinates of its
 * left, top, right and bottom edges.
 *
 * Every transforming method returns a new instance; the receiver is never
 * modified.
 */
export class BoundingBox {
  constructor(
    private readonly _left: number,
    private readonly _top: number,
    private readonly _right: number,
    private readonly _bottom: number
  ) {}

  /** Coordinate of the left edge. */
  public get left(): number {
    return this._left
  }

  /** Coordinate of the top edge. */
  public get top(): number {
    return this._top
  }

  /** Coordinate of the right edge. */
  public get right(): number {
    return this._right
  }

  /** Coordinate of the bottom edge. */
  public get bottom(): number {
    return this._bottom
  }

  /** Horizontal extent, computed as `right - left`. */
  public get width(): number {
    return this.right - this.left
  }

  /** Vertical extent, computed as `bottom - top`. */
  public get height(): number {
    return this.bottom - this.top
  }

  /**
   * Returns a box whose shorter side has been grown symmetrically so that it
   * matches the longer side. The center of the box is preserved. A box that
   * is already square is returned as an equal copy.
   *
   * @returns A new, square `BoundingBox`.
   */
  public toSquare(): BoundingBox {
    const { width, height } = this
    let { left, top, right, bottom } = this

    // half of the side difference is added on each side of the shorter axis
    const off = Math.abs(width - height) / 2
    if (width < height) {
      left -= off
      right += off
    }
    if (height < width) {
      top -= off
      bottom += off
    }
    return new BoundingBox(left, top, right, bottom)
  }

  /**
   * Returns a box with every edge coordinate passed through `Math.round`.
   *
   * @returns A new `BoundingBox` with integer coordinates.
   */
  public round(): BoundingBox {
    return new BoundingBox(
      Math.round(this.left),
      Math.round(this.top),
      Math.round(this.right),
      Math.round(this.bottom)
    )
  }
}

src/mtcnn/Mtcnn.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,18 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
4343
})
4444
}
4545

46-
/**
 * Converts `input` with `toNetInput` and passes the result, together with
 * the detection options, on to `forwardInput`.
 *
 * @param input Anything accepted by `toNetInput`.
 * @param minFaceSize Forwarded unchanged to `forwardInput`; defaults to 20
 * (presumably the smallest face size to search for - confirm against
 * `forwardInput`).
 * @param scaleFactor Forwarded unchanged to `forwardInput`; defaults to 0.709
 * (presumably the image pyramid scale step - confirm against `forwardInput`).
 * @param scoreThresholds Forwarded unchanged to `forwardInput`; defaults to
 * `[0.6, 0.7, 0.7]` (presumably one threshold per stage - confirm against
 * `forwardInput`).
 * @returns Whatever `forwardInput` yields for the converted input.
 */
public async forward(
  input: TNetInput,
  minFaceSize: number = 20,
  scaleFactor: number = 0.709,
  scoreThresholds: number[] = [0.6, 0.7, 0.7]
): Promise<tf.Tensor2D> {
  const netInput = await toNetInput(input, true)

  return this.forwardInput(
    netInput,
    minFaceSize,
    scaleFactor,
    scoreThresholds
  )
}
4959

5060
protected extractParams(weights: Float32Array) {

src/mtcnn/nms.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import { BoundingBox } from './BoundingBox';
2+
3+
/**
 * Greedy non-maximum suppression.
 *
 * Repeatedly picks the remaining box with the highest score and discards
 * every other remaining box whose overlap with it exceeds `iouThreshold`.
 *
 * @param boxes Candidate boxes.
 * @param scores Score of each candidate; `scores[i]` belongs to `boxes[i]`.
 * @param iouThreshold Boxes with an overlap strictly greater than this value
 * are suppressed.
 * @param isIOU When true, overlap is intersection over union; otherwise it is
 * intersection over the smaller of the two box areas.
 * @returns Indices into `boxes` of the kept candidates, ordered from highest
 * to lowest score.
 */
export function nms(
  boxes: BoundingBox[],
  scores: number[],
  iouThreshold: number,
  isIOU: boolean = true
): number[] {

  // the + 1 is applied consistently to areas and to intersection extents
  // (presumably edges are treated as inclusive pixel coordinates)
  const areas = boxes.map(
    box => (box.width + 1) * (box.height + 1)
  )

  // ascending by score, so the best remaining candidate is always the last one
  let remaining = scores
    .map((score, boxIndex) => ({ score, boxIndex }))
    .sort((c1, c2) => c1.score - c2.score)
    .map(c => c.boxIndex)

  const pick: number[] = []

  while (remaining.length > 0) {
    const curr = remaining.pop() as number
    pick.push(curr)

    // invariant for all comparisons of this round
    const currBox = boxes[curr]
    const currArea = areas[curr]

    remaining = remaining.filter(idx => {
      const idxBox = boxes[idx]

      const width = Math.max(0.0, Math.min(currBox.right, idxBox.right) - Math.max(currBox.left, idxBox.left) + 1)
      const height = Math.max(0.0, Math.min(currBox.bottom, idxBox.bottom) - Math.max(currBox.top, idxBox.top) + 1)
      const interSection = width * height

      const overlap = isIOU
        ? interSection / (currArea + areas[idx] - interSection)
        : interSection / Math.min(currArea, areas[idx])

      return overlap <= iouThreshold
    })
  }

  return pick
}

src/mtcnn/stage1.ts

Lines changed: 86 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

33
import { Point } from '../Point';
4+
import { BoundingBox } from './BoundingBox';
45
import { CELL_SIZE, CELL_STRIDE } from './config';
6+
import { nms } from './nms';
57
import { PNet } from './PNet';
68
import { PNetParams } from './types';
79

@@ -18,96 +20,116 @@ function rescaleAndNormalize(x: tf.Tensor4D, scale: number): tf.Tensor4D {
1820

1921

2022
/**
 * Collects every cell of the score map whose score reaches `scoreThreshold`
 * and describes it as a candidate.
 *
 * @param scoresTensor 2D score map, indexed as (y, x).
 * @param regionsTensor 3D tensor, indexed as (y, x, channel); channels 0 to 3
 * are read as the left, top, right and bottom components of the region.
 * @param scale Scale the score map was computed at; cell coordinates are
 * divided by it (presumably to map them back to the unscaled image - confirm
 * against caller).
 * @param scoreThreshold Minimum score (inclusive) a cell needs to be kept.
 * @returns One `{ cell, score, region }` entry per kept cell, in row-major
 * order. Empty if no cell reaches the threshold.
 */
function extractBoundingBoxes(
  scoresTensor: tf.Tensor2D,
  regionsTensor: tf.Tensor3D,
  scale: number,
  scoreThreshold: number
) {

  // TODO: fix this!, maybe better to use tf.gather here
  // each score is read from the tensor once and carried along with its position
  const candidates: { position: Point, score: number }[] = []
  for (let y = 0; y < scoresTensor.shape[0]; y++) {
    for (let x = 0; x < scoresTensor.shape[1]; x++) {
      const score = scoresTensor.get(y, x)
      if (score >= scoreThreshold) {
        candidates.push({ position: new Point(x, y), score })
      }
    }
  }

  const boundingBoxes = candidates.map(({ position, score }) => {
    // extent of the cell, derived from its grid position, CELL_STRIDE and CELL_SIZE
    const cell = new BoundingBox(
      Math.round((position.x * CELL_STRIDE + 1) / scale),
      Math.round((position.y * CELL_STRIDE + 1) / scale),
      Math.round((position.x * CELL_STRIDE + CELL_SIZE) / scale),
      Math.round((position.y * CELL_STRIDE + CELL_SIZE) / scale)
    )

    const region = new BoundingBox(
      regionsTensor.get(position.y, position.x, 0),
      regionsTensor.get(position.y, position.x, 1),
      regionsTensor.get(position.y, position.x, 2),
      regionsTensor.get(position.y, position.x, 3)
    )

    return {
      cell,
      score,
      region
    }
  })

  return boundingBoxes
}
8065

81-
const boxes = scales.map((scale, i) => {
82-
let resized = i === 0
83-
// TODO: debug
84-
? tf.tensor4d(window.resizedData, [1, 820, 461, 3])
66+
/**
 * Runs the first detection stage over all given scales.
 *
 * For every scale the image is rescaled and normalized, passed through
 * `PNet`, and the cells of the resulting score map (channel 1 of `prob`)
 * that reach `scoreThreshold` are extracted and reduced by non-maximum
 * suppression (threshold 0.5). The survivors of all scales are merged,
 * reduced once more (threshold 0.7), shifted by their region offsets
 * (scaled by the cell width / height), squared and rounded.
 *
 * @param imgTensor Input image tensor.
 * @param scales Scales to evaluate; an empty array yields an empty result.
 * @param scoreThreshold Minimum score (inclusive) for a cell to be considered.
 * @param params Weights passed on to `PNet`.
 * @returns The resulting `{ box, score }` entries; empty if nothing was found.
 */
export function stage1(imgTensor: tf.Tensor4D, scales: number[], scoreThreshold: number, params: PNetParams) {

  type Candidate = ReturnType<typeof extractBoundingBoxes>[number]

  const boxesForScale = scales.map((scale): Candidate[] => {

    const { scoresTensor, regionsTensor } = tf.tidy(() => {
      const resized = rescaleAndNormalize(imgTensor, scale)
      const { prob, regions } = PNet(resized, params)

      const scores = tf.unstack(prob, 3)[1]
      const [sh, sw] = scores.shape.slice(1)
      const [rh, rw] = regions.shape.slice(1)

      return {
        scoresTensor: scores.as2D(sh, sw),
        regionsTensor: regions.as3D(rh, rw, 4)
      }
    })

    // tensors returned from tidy are owned by us, release them even if
    // the extraction throws
    let boundingBoxes: Candidate[]
    try {
      boundingBoxes = extractBoundingBoxes(
        scoresTensor,
        regionsTensor,
        scale,
        scoreThreshold
      )
    } finally {
      scoresTensor.dispose()
      regionsTensor.dispose()
    }

    if (!boundingBoxes.length) {
      return []
    }

    const indices = nms(
      boundingBoxes.map(bbox => bbox.cell),
      boundingBoxes.map(bbox => bbox.score),
      0.5
    )

    return indices.map(boxIdx => boundingBoxes[boxIdx])
  })

  // the initial value keeps reduce from throwing when scales is empty
  const allBoxes = boxesForScale.reduce(
    (all, boxes) => all.concat(boxes),
    [] as Candidate[]
  )

  if (!allBoxes.length) {
    return []
  }

  const indices = nms(
    allBoxes.map(bbox => bbox.cell),
    allBoxes.map(bbox => bbox.score),
    0.7
  )

  const finalBoxes = indices
    .map(idx => allBoxes[idx])
    .map(({ cell, region, score }) => ({
      // argument order has to match BoundingBox(left, top, right, bottom)
      box: new BoundingBox(
        cell.left + (region.left * cell.width),
        cell.top + (region.top * cell.height),
        cell.right + (region.right * cell.width),
        cell.bottom + (region.bottom * cell.height)
      ).toSquare().round(),
      score
    }))

  return finalBoxes
}

0 commit comments

Comments
 (0)