Skip to content

Commit 2ae7636

Browse files
NetInput to simplify api
1 parent b23e637 commit 2ae7636

File tree

10 files changed

+172
-116
lines changed

10 files changed

+172
-116
lines changed

src/NetInput.ts

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import { TMediaElement, TNetInput } from './types';
2+
import { Dimensions, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
3+
4+
export class NetInput {
5+
private _canvases: HTMLCanvasElement[]
6+
7+
constructor(
8+
mediaArg: TNetInput,
9+
dims?: Dimensions
10+
) {
11+
const mediaArgArray = Array.isArray(mediaArg)
12+
? mediaArg
13+
: [mediaArg]
14+
15+
if (!mediaArgArray.length) {
16+
throw new Error('NetInput - empty array passed as input')
17+
}
18+
19+
const medias = mediaArgArray.map(getElement)
20+
21+
22+
medias.forEach((media, i) => {
23+
if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement || media instanceof HTMLCanvasElement)) {
24+
const idxHint = Array.isArray(mediaArg) ? ` at input index ${i}:` : ''
25+
if (typeof mediaArgArray[i] === 'string') {
26+
throw new Error(`NetInput -${idxHint} string passed, but could not resolve HTMLElement for element id`)
27+
}
28+
throw new Error(`NetInput -${idxHint} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, or to be an element id`)
29+
}
30+
})
31+
32+
this._canvases = []
33+
medias.forEach(m => this.initCanvas(m, dims))
34+
}
35+
36+
private initCanvas(media: TMediaElement, dims?: Dimensions) {
37+
if (media instanceof HTMLCanvasElement) {
38+
this._canvases.push(media)
39+
return
40+
}
41+
42+
// if input is batch type, make sure every canvas has the same dimensions
43+
const { width, height } = this.dims || dims || getMediaDimensions(media)
44+
45+
const canvas = document.createElement('canvas')
46+
canvas.width = width
47+
canvas.height = height
48+
49+
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
50+
this._canvases.push(canvas)
51+
}
52+
53+
public get canvases() : HTMLCanvasElement[] {
54+
return this._canvases
55+
}
56+
57+
public get width() : number {
58+
return (this._canvases[0] || {}).width
59+
}
60+
61+
public get height() : number {
62+
return (this._canvases[0] || {}).height
63+
}
64+
65+
public get dims() : Dimensions | null {
66+
const { width, height } = this
67+
return (width > 0 && height > 0) ? { width, height } : null
68+
}
69+
}

src/faceDetectionNet/boxPredictionLayer.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,19 @@ function convWithBias(
1616

1717
export function boxPredictionLayer(
1818
x: tf.Tensor4D,
19-
params: FaceDetectionNet.BoxPredictionParams,
20-
size: number
19+
params: FaceDetectionNet.BoxPredictionParams
2120
) {
2221
return tf.tidy(() => {
2322

2423
const batchSize = x.shape[0]
2524

2625
const boxPredictionEncoding = tf.reshape(
2726
convWithBias(x, params.box_encoding_predictor_params),
28-
[batchSize, size, 1, 4]
27+
[batchSize, -1, 1, 4]
2928
)
3029
const classPrediction = tf.reshape(
3130
convWithBias(x, params.class_predictor_params),
32-
[batchSize, size, 3]
31+
[batchSize, -1, 3]
3332
)
3433

3534
return {

src/faceDetectionNet/index.ts

Lines changed: 7 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { isFloat } from '../utils';
3+
import { NetInput } from '../NetInput';
4+
import { getImageTensor, padToSquare } from '../transformInputs';
5+
import { TNetInput } from '../types';
46
import { extractParams } from './extractParams';
57
import { FaceDetectionResult } from './FaceDetectionResult';
68
import { mobileNetV1 } from './mobileNetV1';
@@ -9,81 +11,6 @@ import { outputLayer } from './outputLayer';
911
import { predictionLayer } from './predictionLayer';
1012
import { resizeLayer } from './resizeLayer';
1113

12-
function fromData(input: number[]): tf.Tensor4D {
13-
const pxPerChannel = input.length / 3
14-
const dim = Math.sqrt(pxPerChannel)
15-
16-
if (isFloat(dim)) {
17-
throw new Error(`invalid input size: ${dim}x${dim}x3 (array length: ${input.length})`)
18-
}
19-
20-
return tf.tensor4d(input as number[], [1, dim, dim, 3])
21-
}
22-
23-
function fromImageData(input: ImageData[]) {
24-
return tf.tidy(() => {
25-
const idx = input.findIndex(data => !(data instanceof ImageData))
26-
if (idx !== -1) {
27-
throw new Error(`expected input at index ${idx} to be instanceof ImageData`)
28-
}
29-
30-
const imgTensors = input
31-
.map(data => tf.fromPixels(data))
32-
.map(data => tf.expandDims(data, 0)) as tf.Tensor4D[]
33-
34-
return tf.cast(tf.concat(imgTensors, 0), 'float32')
35-
})
36-
}
37-
38-
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
39-
return tf.tidy(() => {
40-
41-
const [_, height, width] = imgTensor.shape
42-
if (height === width) {
43-
return imgTensor
44-
}
45-
46-
if (height > width) {
47-
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
48-
return tf.concat([imgTensor, pad], 2)
49-
}
50-
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
51-
return tf.concat([imgTensor, pad], 1)
52-
})
53-
}
54-
55-
function getImgTensor(input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[]) {
56-
return tf.tidy(() => {
57-
if (input instanceof HTMLCanvasElement) {
58-
return tf.cast(
59-
tf.expandDims(tf.fromPixels(input), 0), 'float32'
60-
) as tf.Tensor4D
61-
}
62-
if (input instanceof tf.Tensor) {
63-
const rank = input.shape.length
64-
if (rank !== 3 && rank !== 4) {
65-
throw new Error('input tensor must be of rank 3 or 4')
66-
}
67-
return tf.cast(
68-
rank === 3 ? tf.expandDims(input, 0) : input, 'float32'
69-
) as tf.Tensor4D
70-
}
71-
72-
const imgDataArray = input instanceof ImageData
73-
? [input]
74-
: (
75-
input[0] instanceof ImageData
76-
? input as ImageData[]
77-
: null
78-
)
79-
80-
return imgDataArray !== null
81-
? fromImageData(imgDataArray)
82-
: fromData(input as number[])
83-
84-
})
85-
}
86-
8714
export function faceDetectionNet(weights: Float32Array) {
8815
const params = extractParams(weights)
8916

@@ -102,14 +29,14 @@ export function faceDetectionNet(weights: Float32Array) {
10229
})
10330
}
10431

105-
function forward(input: tf.Tensor|ImageData|ImageData[]|number[]) {
32+
function forward(input: tf.Tensor | NetInput | TNetInput) {
10633
return tf.tidy(
107-
() => forwardTensor(padToSquare(getImgTensor(input)))
34+
() => forwardTensor(padToSquare(getImageTensor(input)))
10835
)
10936
}
11037

11138
async function locateFaces(
112-
input: tf.Tensor|HTMLCanvasElement|ImageData|ImageData[]|number[],
39+
input: tf.Tensor | NetInput,
11340
minConfidence: number = 0.8,
11441
maxResults: number = 100,
11542
): Promise<FaceDetectionResult[]> {
@@ -121,7 +48,7 @@ export function faceDetectionNet(weights: Float32Array) {
12148
scores: _scores
12249
} = tf.tidy(() => {
12350

124-
let imgTensor = getImgTensor(input)
51+
let imgTensor = getImageTensor(input)
12552
const [_, height, width] = imgTensor.shape
12653

12754
imgTensor = padToSquare(imgTensor)
@@ -140,9 +67,7 @@ export function faceDetectionNet(weights: Float32Array) {
14067
}
14168

14269
// TODO find a better way to filter by minConfidence
143-
//const ts = Date.now()
14470
const scoresData = Array.from(await scores.data())
145-
//console.log('await data:', (Date.now() - ts))
14671

14772
const iouThreshold = 0.5
14873
const indices = nonMaxSuppression(

src/faceDetectionNet/predictionLayer.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ import { boxPredictionLayer } from './boxPredictionLayer';
44
import { pointwiseConvLayer } from './pointwiseConvLayer';
55
import { FaceDetectionNet } from './types';
66

7-
export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: FaceDetectionNet.PredictionLayerParams) {
7+
export function predictionLayer(
8+
x: tf.Tensor4D,
9+
conv11: tf.Tensor4D,
10+
params: FaceDetectionNet.PredictionLayerParams
11+
) {
812
return tf.tidy(() => {
913

1014
const conv0 = pointwiseConvLayer(x, params.conv_0_params, [1, 1])
@@ -16,12 +20,12 @@ export function predictionLayer(x: tf.Tensor4D, conv11: tf.Tensor4D, params: Fac
1620
const conv6 = pointwiseConvLayer(conv5, params.conv_6_params, [1, 1])
1721
const conv7 = pointwiseConvLayer(conv6, params.conv_7_params, [2, 2])
1822

19-
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params, 3072)
20-
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params, 1536)
21-
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params, 384)
22-
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params, 96)
23-
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params, 24)
24-
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params, 6)
23+
const boxPrediction0 = boxPredictionLayer(conv11, params.box_predictor_0_params)
24+
const boxPrediction1 = boxPredictionLayer(x, params.box_predictor_1_params)
25+
const boxPrediction2 = boxPredictionLayer(conv1, params.box_predictor_2_params)
26+
const boxPrediction3 = boxPredictionLayer(conv3, params.box_predictor_3_params)
27+
const boxPrediction4 = boxPredictionLayer(conv5, params.box_predictor_4_params)
28+
const boxPrediction5 = boxPredictionLayer(conv7, params.box_predictor_5_params)
2529

2630
const boxPredictions = tf.concat([
2731
boxPrediction0.boxPredictionEncoding,

src/faceRecognitionNet/index.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
import { normalize } from '../normalize';
3+
import { NetInput } from '../NetInput';
4+
import { getImageTensor, padToSquare } from '../transformInputs';
5+
import { TNetInput } from '../types';
46
import { convDown } from './convLayer';
57
import { extractParams } from './extractParams';
8+
import { normalize } from './normalize';
69
import { residual, residualDown } from './residualLayer';
710

811
export function faceRecognitionNet(weights: Float32Array) {
912
const params = extractParams(weights)
1013

11-
function forward(input: number[] | ImageData) {
12-
14+
function forward(input: tf.Tensor | NetInput | TNetInput) {
1315
return tf.tidy(() => {
1416

15-
const x = normalize(input)
17+
const x = normalize(padToSquare(getImageTensor(input)))
1618

1719
let out = convDown(x, params.conv32_down)
1820
out = tf.maxPool(out, 3, 2, 'valid')
@@ -42,14 +44,14 @@ export function faceRecognitionNet(weights: Float32Array) {
4244
})
4345
}
4446

45-
const computeFaceDescriptor = async (input: number[] | ImageData) => {
47+
const computeFaceDescriptor = async (input: tf.Tensor | NetInput | TNetInput) => {
4648
const result = forward(input)
4749
const data = await result.data()
4850
result.dispose()
4951
return data
5052
}
5153

52-
const computeFaceDescriptorSync = (input: number[] | ImageData) => {
54+
const computeFaceDescriptorSync = (input: tf.Tensor | NetInput | TNetInput) => {
5355
const result = forward(input)
5456
const data = result.dataSync()
5557
result.dispose()
Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3-
export function normalize(input: number[] | ImageData): tf.Tensor4D {
3+
export function normalize(x: tf.Tensor4D): tf.Tensor4D {
44
return tf.tidy(() => {
55
const avg_r = tf.fill([1, 150, 150, 1], 122.782);
66
const avg_g = tf.fill([1, 150, 150, 1], 117.001);
77
const avg_b = tf.fill([1, 150, 150, 1], 104.298);
88
const avg_rgb = tf.concat([avg_r, avg_g, avg_b], 3)
99

10-
const x = input instanceof ImageData
11-
? tf.cast(tf.reshape(tf.fromPixels(input), [1, 150, 150, 3]), 'float32')
12-
: tf.tensor4d(input, [1, 150, 150, 3])
13-
return tf.div(tf.sub(x, avg_rgb), tf.fill(x.shape, 256))
10+
return tf.div(tf.sub(x, avg_rgb), tf.scalar(256))
1411
})
1512
}

src/index.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1-
import * as tf from '@tensorflow/tfjs-core';
2-
31
import { euclideanDistance } from './euclideanDistance';
42
import { faceDetectionNet } from './faceDetectionNet';
53
import { faceRecognitionNet } from './faceRecognitionNet';
6-
import { normalize } from './normalize';
4+
import { NetInput } from './NetInput';
75

86
export {
97
euclideanDistance,
108
faceDetectionNet,
119
faceRecognitionNet,
12-
normalize,
13-
tf
10+
NetInput
1411
}
1512

1613
export * from './utils'

src/transformInputs.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
import { NetInput } from './NetInput';
4+
import { TNetInput } from './types';
5+
6+
export function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
7+
return tf.tidy(() => {
8+
9+
const [_, height, width] = imgTensor.shape
10+
if (height === width) {
11+
return imgTensor
12+
}
13+
14+
if (height > width) {
15+
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
16+
return tf.concat([imgTensor, pad], 2)
17+
}
18+
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
19+
return tf.concat([imgTensor, pad], 1)
20+
})
21+
}
22+
23+
export function getImageTensor(input: tf.Tensor | NetInput | TNetInput): tf.Tensor4D {
24+
return tf.tidy(() => {
25+
if (input instanceof tf.Tensor) {
26+
const rank = input.shape.length
27+
if (rank !== 3 && rank !== 4) {
28+
throw new Error('input tensor must be of rank 3 or 4')
29+
}
30+
31+
return (rank === 3 ? input.expandDims(0) : input).toFloat() as tf.Tensor4D
32+
}
33+
34+
const netInput = input instanceof NetInput ? input : new NetInput(input)
35+
return tf.concat(
36+
netInput.canvases.map(canvas =>
37+
tf.fromPixels(canvas).expandDims(0).toFloat()
38+
)
39+
) as tf.Tensor4D
40+
})
41+
}

src/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export type TMediaElement = HTMLImageElement | HTMLVideoElement | HTMLCanvasElement
2+
export type TNetInputArg = string | TMediaElement
3+
export type TNetInput = TNetInputArg | Array<TNetInputArg>

0 commit comments

Comments
 (0)