Skip to content

Commit 18f23a1

Browse files
fixed memory leaks + accept Tensors and HTMLCanvasElement as inputs
1 parent 3ca0f4a commit 18f23a1

File tree

3 files changed

+123
-46
lines changed

3 files changed

+123
-46
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { FaceDetectionNet } from './types';
2+
3+
export class FaceDetectionResult {
4+
private score: number
5+
private top: number
6+
private left: number
7+
private bottom: number
8+
private right: number
9+
10+
constructor(
11+
score: number,
12+
top: number,
13+
left: number,
14+
bottom: number,
15+
right: number
16+
) {
17+
this.score = score
18+
this.top = Math.max(0, top),
19+
this.left = Math.max(0, left),
20+
this.bottom = Math.min(1.0, bottom),
21+
this.right = Math.min(1.0, right)
22+
}
23+
24+
public forSize(width: number, height: number): FaceDetectionNet.Detection {
25+
return {
26+
score: this.score,
27+
box: {
28+
top: this.top * height,
29+
left: this.left * width,
30+
bottom: this.bottom * height,
31+
right: this.right * width
32+
}
33+
}
34+
}
35+
}

src/faceDetectionNet/index.ts

Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { isFloat } from '../utils';
44
import { extractParams } from './extractParams';
5+
import { FaceDetectionResult } from './FaceDetectionResult';
56
import { mobileNetV1 } from './mobileNetV1';
6-
import { resizeLayer } from './resizeLayer';
7-
import { predictionLayer } from './predictionLayer';
8-
import { outputLayer } from './outputLayer';
97
import { nonMaxSuppression } from './nonMaxSuppression';
10-
import { FaceDetectionNet } from './types';
8+
import { outputLayer } from './outputLayer';
9+
import { predictionLayer } from './predictionLayer';
10+
import { resizeLayer } from './resizeLayer';
1111

1212
function fromData(input: number[]): tf.Tensor4D {
1313
const pxPerChannel = input.length / 3
@@ -21,34 +21,53 @@ function fromData(input: number[]): tf.Tensor4D {
2121
}
2222

2323
function fromImageData(input: ImageData[]) {
24-
const idx = input.findIndex(data => !(data instanceof ImageData))
25-
if (idx !== -1) {
26-
throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
27-
}
24+
return tf.tidy(() => {
25+
const idx = input.findIndex(data => !(data instanceof ImageData))
26+
if (idx !== -1) {
27+
throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
28+
}
2829

29-
const imgTensors = input
30-
.map(data => tf.fromPixels(data))
31-
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
30+
const imgTensors = input
31+
.map(data => tf.fromPixels(data))
32+
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
3233

33-
return tf.cast(tf.concat(imgTensors, 0), 'float32')
34+
return tf.cast(tf.concat(imgTensors, 0), 'float32')
35+
})
3436
}
3537

3638
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
37-
const [_, height, width] = imgTensor.shape
38-
if (height === width) {
39-
return imgTensor
40-
}
39+
return tf.tidy(() => {
4140

42-
if (height > width) {
43-
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
44-
return tf.concat([imgTensor, pad], 2)
45-
}
46-
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
47-
return tf.concat([imgTensor, pad], 1)
41+
const [_, height, width] = imgTensor.shape
42+
if (height === width) {
43+
return imgTensor
44+
}
45+
46+
if (height > width) {
47+
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
48+
return tf.concat([imgTensor, pad], 2)
49+
}
50+
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
51+
return tf.concat([imgTensor, pad], 1)
52+
})
4853
}
4954

50-
function getImgTensor(input: ImageData|ImageData[]|number[]) {
55+
function getImgTensor(input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[]) {
5156
return tf.tidy(() => {
57+
if (input instanceof HTMLCanvasElement) {
58+
return tf.cast(
59+
tf.expandDims(tf.fromPixels(input), 0), 'float32'
60+
) as tf.Tensor4D
61+
}
62+
if (input instanceof tf.Tensor) {
63+
const rank = input.shape.length
64+
if (rank !== 3 && rank !== 4) {
65+
throw new Error('input tensor must be of rank 3 or 4')
66+
}
67+
return tf.cast(
68+
rank === 3 ? tf.expandDims(input, 0) : input, 'float32'
69+
) as tf.Tensor4D
70+
}
5271

5372
const imgDataArray = input instanceof ImageData
5473
? [input]
@@ -58,11 +77,9 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) {
5877
: null
5978
)
6079

61-
return padToSquare(
62-
imgDataArray !== null
63-
? fromImageData(imgDataArray)
64-
: fromData(input as number[])
65-
)
80+
return imgDataArray !== null
81+
? fromImageData(imgDataArray)
82+
: fromData(input as number[])
6683

6784
})
6885
}
@@ -85,31 +102,47 @@ export function faceDetectionNet(weights: Float32Array) {
85102
})
86103
}
87104

88-
function forward(input: ImageData|ImageData[]|number[]) {
105+
function forward(input: tf.Tensor|ImageData|ImageData[]|number[]) {
89106
return tf.tidy(
90107
() => forwardTensor(padToSquare(getImgTensor(input)))
91108
)
92109
}
93110

94111
async function locateFaces(
95-
input: ImageData|ImageData[]|number[],
112+
input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[],
96113
minConfidence: number = 0.8,
97114
maxResults: number = 100,
98-
): Promise<FaceDetectionNet.Detection[]> {
99-
const imgTensor = getImgTensor(input)
100-
const [_, height, width] = imgTensor.shape
115+
): Promise<FaceDetectionResult[]> {
116+
117+
let paddedHeightRelative = 1, paddedWidthRelative = 1
101118

102119
const {
103120
boxes: _boxes,
104121
scores: _scores
105-
} = forwardTensor(imgTensor)
122+
} = tf.tidy(() => {
123+
124+
let imgTensor = getImgTensor(input)
125+
const [_, height, width] = imgTensor.shape
126+
127+
imgTensor = padToSquare(imgTensor)
128+
paddedHeightRelative = imgTensor.shape[1] / height
129+
paddedWidthRelative = imgTensor.shape[2] / width
130+
131+
return forwardTensor(imgTensor)
132+
})
106133

107134
// TODO batches
108135
const boxes = _boxes[0]
109136
const scores = _scores[0]
137+
for (let i = 1; i < _boxes.length; i++) {
138+
_boxes[i].dispose()
139+
_scores[i].dispose()
140+
}
110141

111142
// TODO find a better way to filter by minConfidence
143+
//const ts = Date.now()
112144
const scoresData = Array.from(await scores.data())
145+
//console.log('await data:', (Date.now() - ts))
113146

114147
const iouThreshold = 0.5
115148
const indices = nonMaxSuppression(
@@ -120,17 +153,19 @@ export function faceDetectionNet(weights: Float32Array) {
120153
minConfidence
121154
)
122155

123-
return indices
124-
.map(idx => ({
125-
score: scoresData[idx],
126-
box: {
127-
top: Math.max(0, height * boxes.get(idx, 0)),
128-
left: Math.max(0, width * boxes.get(idx, 1)),
129-
bottom: Math.min(height, height * boxes.get(idx, 2)),
130-
right: Math.min(width, width * boxes.get(idx, 3))
131-
}
132-
}))
156+
const results = indices
157+
.map(idx => new FaceDetectionResult(
158+
scoresData[idx],
159+
boxes.get(idx, 0) * paddedHeightRelative,
160+
boxes.get(idx, 1) * paddedWidthRelative,
161+
boxes.get(idx, 2) * paddedHeightRelative,
162+
boxes.get(idx, 3) * paddedWidthRelative
163+
))
164+
165+
boxes.dispose()
166+
scores.dispose()
133167

168+
return results
134169
}
135170

136171
return {

src/utils.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingContext2
1515
return ctx
1616
}
1717

18+
function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
19+
if (media instanceof HTMLVideoElement) {
20+
return { width: media.videoWidth, height: media.videoHeight }
21+
}
22+
return media
23+
}
24+
1825
export function isFloat(num: number) {
1926
return num % 1 !== 0
2027
}
@@ -43,7 +50,7 @@ export function drawMediaToCanvas(
4350
throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement')
4451
}
4552

46-
const { width, height } = dims || media
53+
const { width, height } = dims || getMediaDimensions(media)
4754
canvas.width = width
4855
canvas.height = height
4956

@@ -59,7 +66,7 @@ export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dim
5966

6067
const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
6168

62-
const { width, height } = dims || media
69+
const { width, height } = dims || getMediaDimensions(media)
6370
return ctx.getImageData(0, 0, width, height)
6471
}
6572

0 commit comments

Comments
 (0)