Skip to content

Commit 2b127e4

Browse files
add centering option to padToSquare and center input face of face recognition net
1 parent d7962a5 commit 2b127e4

File tree

6 files changed

+42
-25
lines changed

6 files changed

+42
-25
lines changed

src/extractFaceTensors.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';
22

33
import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
44
import { NetInput } from './NetInput';
5-
import { getImageTensor } from './transformInputs';
5+
import { getImageTensor } from './getImageTensor';
66
import { TNetInput } from './types';
77

88
/**

src/faceDetectionNet/index.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { getImageTensor } from '../getImageTensor';
34
import { NetInput } from '../NetInput';
4-
import { getImageTensor, padToSquare } from '../transformInputs';
5+
import { padToSquare } from '../padToSquare';
56
import { TNetInput } from '../types';
67
import { extractParams } from './extractParams';
78
import { FaceDetectionResult } from './FaceDetectionResult';
@@ -49,7 +50,7 @@ export function faceDetectionNet(weights: Float32Array) {
4950
} = tf.tidy(() => {
5051

5152
let imgTensor = getImageTensor(input)
52-
const [_, height, width] = imgTensor.shape
53+
const [height, width] = imgTensor.shape.slice(1)
5354

5455
imgTensor = padToSquare(imgTensor)
5556
paddedHeightRelative = imgTensor.shape[1] / height

src/faceRecognitionNet/index.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { getImageTensor } from '../getImageTensor';
34
import { NetInput } from '../NetInput';
4-
import { getImageTensor, padToSquare } from '../transformInputs';
5+
import { padToSquare } from '../padToSquare';
56
import { TNetInput } from '../types';
67
import { convDown } from './convLayer';
78
import { extractParams } from './extractParams';
@@ -14,8 +15,7 @@ export function faceRecognitionNet(weights: Float32Array) {
1415
function forward(input: tf.Tensor | NetInput | TNetInput) {
1516
return tf.tidy(() => {
1617

17-
// TODO pad on both sides, to keep face centered
18-
let x = padToSquare(getImageTensor(input))
18+
let x = padToSquare(getImageTensor(input), true)
1919
// work with 150 x 150 sized face images
2020
if (x.shape[1] !== 150 || x.shape[2] !== 150) {
2121
x = tf.image.resizeBilinear(x, [150, 150])

src/transformInputs.ts renamed to src/getImageTensor.ts

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,6 @@ import * as tf from '@tensorflow/tfjs-core';
33
import { NetInput } from './NetInput';
44
import { TNetInput } from './types';
55

6-
export function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
7-
return tf.tidy(() => {
8-
9-
const [_, height, width] = imgTensor.shape
10-
if (height === width) {
11-
return imgTensor
12-
}
13-
14-
if (height > width) {
15-
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
16-
return tf.concat([imgTensor, pad], 2)
17-
}
18-
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
19-
return tf.concat([imgTensor, pad], 1)
20-
})
21-
}
22-
236
export function getImageTensor(input: tf.Tensor | NetInput | TNetInput): tf.Tensor4D {
247
return tf.tidy(() => {
258
if (input instanceof tf.Tensor) {

src/index.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
13
import { euclideanDistance } from './euclideanDistance';
24
import { faceDetectionNet } from './faceDetectionNet';
35
import { faceRecognitionNet } from './faceRecognitionNet';
46
import { NetInput } from './NetInput';
5-
import * as tf from '@tensorflow/tfjs-core';
7+
import { padToSquare } from './padToSquare';
68

79
export {
810
euclideanDistance,
911
faceDetectionNet,
1012
faceRecognitionNet,
1113
NetInput,
12-
tf
14+
tf,
15+
padToSquare
1316
}
1417

1518
export * from './extractFaces'

src/padToSquare.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
3+
/**
 * Pads the smaller dimension of an image tensor with zeros, such that width === height.
 *
 * @param imgTensor The image tensor. Dimensions 1 and 2 of its shape are treated as height and width.
 * @param isCenterImage (optional, default: false) If true, add padding on both sides of the image, such that the image
 * is centered. If the difference between height and width is odd, the trailing side receives the one extra row/column.
 * @returns The padded tensor with width === height. An already square input is returned as is.
 */
export function padToSquare(
  imgTensor: tf.Tensor4D,
  isCenterImage: boolean = false
): tf.Tensor4D {
  return tf.tidy(() => {

    const [height, width] = imgTensor.shape.slice(1)
    if (height === width) {
      return imgTensor
    }

    const dimDiff = Math.abs(height - width)
    // pad along the width axis (2) if the image is taller than wide, else along the height axis (1)
    const paddingAxis = height > width ? 2 : 1

    // When centering, split the missing amount between both sides. The trailing side takes the
    // remainder, so that both parts always add up to dimDiff. Flooring both halves would leave
    // the result one row/column short of a square whenever dimDiff is odd.
    const leadingPaddingAmount = isCenterImage ? Math.floor(dimDiff / 2) : 0
    const trailingPaddingAmount = dimDiff - leadingPaddingAmount

    // creates a zero filled tensor, which matches the image in every dimension except the padded one
    const createPaddingTensor = (paddingAmount: number): tf.Tensor4D => {
      const paddingTensorShape = imgTensor.shape.slice() as [number, number, number, number]
      paddingTensorShape[paddingAxis] = paddingAmount
      return tf.fill(paddingTensorShape, 0) as tf.Tensor4D
    }

    const tensorsToStack: tf.Tensor4D[] = []
    // leadingPaddingAmount is 0 if not centering or if dimDiff === 1, no need to concat an empty tensor then
    if (leadingPaddingAmount > 0) {
      tensorsToStack.push(createPaddingTensor(leadingPaddingAmount))
    }
    // trailingPaddingAmount is always >= 1 here, since dimDiff >= 1
    tensorsToStack.push(imgTensor, createPaddingTensor(trailingPaddingAmount))

    return tf.concat(tensorsToStack, paddingAxis)
  })
}

0 commit comments

Comments
 (0)