Skip to content

Commit 45c9d6f

Browse files
finalize architecture
1 parent 23d4664 commit 45c9d6f

File tree

3 files changed

+87
-30
lines changed

3 files changed

+87
-30
lines changed

src/faceDetectionNet/index.ts

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function fromData(input: number[]): tf.Tensor4D {
1515
throw new Error(`invalid input size: ${dim}x${dim}x3 (array length: ${input.length})`)
1616
}
1717

18-
return tf.tensor4d(input as number[], [1, 580, 580, 3])
18+
return tf.tensor4d(input as number[], [1, dim, dim, 3])
1919
}
2020

2121
function fromImageData(input: ImageData[]) {
@@ -31,24 +31,30 @@ function fromImageData(input: ImageData[]) {
3131
return tf.cast(tf.concat(imgTensors, 0), 'float32')
3232
}
3333

34+
function getImgTensor(input: ImageData|ImageData[]|number[]) {
35+
return tf.tidy(() => {
36+
37+
const imgDataArray = input instanceof ImageData
38+
? [input]
39+
: (
40+
input[0] instanceof ImageData
41+
? input as ImageData[]
42+
: null
43+
)
44+
45+
return imgDataArray !== null
46+
? fromImageData(imgDataArray)
47+
: fromData(input as number[])
48+
49+
})
50+
}
51+
3452
export function faceDetectionNet(weights: Float32Array) {
3553
const params = extractParams(weights)
3654

37-
async function forward(input: ImageData|ImageData[]|number[]) {
55+
function forwardTensor(imgTensor: tf.Tensor4D) {
3856
return tf.tidy(() => {
3957

40-
const imgDataArray = input instanceof ImageData
41-
? [input]
42-
: (
43-
input[0] instanceof ImageData
44-
? input as ImageData[]
45-
: null
46-
)
47-
48-
const imgTensor = imgDataArray !== null
49-
? fromImageData(imgDataArray)
50-
: fromData(input as number[])
51-
5258
const resized = resizeLayer(imgTensor) as tf.Tensor4D
5359
const features = mobileNetV1(resized, params.mobilenetv1_params)
5460

@@ -57,14 +63,54 @@ export function faceDetectionNet(weights: Float32Array) {
5763
classPredictions
5864
} = predictionLayer(features.out, features.conv11, params.prediction_layer_params)
5965

60-
const decoded = outputLayer(boxPredictions, classPredictions, params.output_layer_params)
66+
return outputLayer(boxPredictions, classPredictions, params.output_layer_params)
67+
})
68+
}
69+
70+
// TODO debug output
71+
function forward(input: ImageData|ImageData[]|number[]) {
72+
return tf.tidy(
73+
() => forwardTensor(getImgTensor(input))
74+
)
75+
}
6176

62-
return decoded
77+
async function locateFaces(
78+
input: ImageData|ImageData[]|number[],
79+
minConfidence: number = 0.8
80+
) {
81+
const imgTensor = getImgTensor(input)
82+
83+
const [_, height, width] = imgTensor.shape
84+
85+
const {
86+
boxes: _boxes,
87+
scores: _scores
88+
} = forwardTensor(imgTensor)
89+
90+
// TODO batches
91+
const boxes = _boxes[0]
92+
const scores = _scores[0]
93+
94+
// TODO find a better way to filter by minConfidence
95+
const data = await scores.data()
96+
97+
return Array.from(data)
98+
.map((score, idx) => ({ score, idx }))
99+
.filter(({ score }) => minConfidence < score)
100+
.map(({ score, idx }) => ({
101+
score,
102+
box: {
103+
left: Math.max(0, width * boxes.get(idx, 0)),
104+
right: Math.min(width, width * boxes.get(idx, 1)),
105+
top: Math.max(0, height * boxes.get(idx, 2)),
106+
bottom: Math.min(height, height * boxes.get(idx, 3))
107+
}
108+
}))
63109

64-
})
65110
}
66111

67112
return {
68-
forward
113+
forward,
114+
locateFaces
69115
}
70116
}

src/faceDetectionNet/outputLayer.ts

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,6 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { FaceDetectionNet } from './types';
44

5-
6-
function batchMultiClassNonMaxSuppressionLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
7-
// TODO
8-
return x0
9-
}
10-
115
function getCenterCoordinatesAndSizesLayer(x: tf.Tensor2D) {
126
const vec = tf.unstack(tf.transpose(x, [1, 0]))
137

@@ -27,7 +21,7 @@ function getCenterCoordinatesAndSizesLayer(x: tf.Tensor2D) {
2721
}
2822
}
2923

30-
function decodeLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
24+
function decodeBoxesLayer(x0: tf.Tensor2D, x1: tf.Tensor2D) {
3125
const {
3226
sizes,
3327
centers
@@ -61,15 +55,30 @@ export function outputLayer(
6155

6256
const batchSize = boxPredictions.shape[0]
6357

64-
const decoded = decodeLayer(
58+
let boxes = decodeBoxesLayer(
6559
tf.reshape(tf.tile(params.extra_dim, [batchSize, 1, 1]), [-1, 4]) as tf.Tensor2D,
6660
tf.reshape(boxPredictions, [-1, 4]) as tf.Tensor2D
6761
)
62+
boxes = tf.reshape(
63+
boxes,
64+
[batchSize, (boxes.shape[0] / batchSize), 4]
65+
)
66+
67+
const scoresAndClasses = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1]))
68+
let scores = tf.slice(scoresAndClasses, [0, 0, 0], [-1, -1, 1]) as tf.Tensor
69+
70+
scores = tf.reshape(
71+
scores,
72+
[batchSize, scores.shape[1]]
73+
)
6874

69-
const in1 = tf.sigmoid(tf.slice(classPredictions, [0, 0, 1], [-1, -1, -1]))
70-
const in2 = tf.expandDims(tf.reshape(decoded, [batchSize, 5118, 4]), 2)
75+
const boxesByBatch = tf.unstack(boxes) as tf.Tensor2D[]
76+
const scoresByBatch = tf.unstack(scores) as tf.Tensor1D[]
7177

72-
return decoded
78+
return {
79+
boxes: boxesByBatch,
80+
scores: scoresByBatch
81+
}
7382

7483
})
7584
}

src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ import { euclideanDistance } from './euclideanDistance';
22
import { faceDetectionNet } from './faceDetectionNet';
33
import { faceRecognitionNet } from './faceRecognitionNet';
44
import { normalize } from './normalize';
5+
import * as tf from '@tensorflow/tfjs-core';
56

67
export {
78
euclideanDistance,
89
faceDetectionNet,
910
faceRecognitionNet,
10-
normalize
11+
normalize,
12+
tf
1113
}

0 commit comments

Comments
 (0)