Skip to content

Commit 554bbe0

Browse files
Pad the input of the face landmark net to a square and center it, to avoid stretching non-square images
1 parent fd69881 commit 554bbe0

File tree

3 files changed

+78
-21
lines changed

3 files changed

+78
-21
lines changed

src/Rect.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,20 @@ export class Rect implements IRect {
1818
this.height = height
1919
}
2020

21+
/**
 * Returns a new square Rect that encloses this one; this Rect is not modified.
 *
 * The shorter side is grown to the length of the longer side and the origin
 * is moved back by half of the added length along that axis, so the original
 * box stays centered inside the result. A Rect that is already square (or
 * whose sides cannot be ordered, e.g. NaN) is returned as an unchanged copy.
 *
 * @returns A new Rect whose width and height are equal whenever the two
 *          sides of this Rect can be compared.
 */
public toSquare(): Rect {
  const { x, y, width, height } = this

  // The branches are mutually exclusive and the grown side is assigned the
  // longer side's value directly. Re-testing a side computed as
  // `width + (height - width)` can round to a value above `height` for
  // fractional sizes, which would then pad the other axis as well.
  if (width < height) {
    return new Rect(x - (height - width) / 2, y, height, height)
  }
  if (height < width) {
    return new Rect(x, y - (width - height) / 2, width, width)
  }
  return new Rect(x, y, width, height)
}
34+
2135
public floor(): Rect {
2236
return new Rect(
2337
Math.floor(this.x),

src/faceLandmarkNet/FaceLandmarkNet.ts

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import { convLayer } from '../commons/convLayer';
44
import { getImageTensor } from '../commons/getImageTensor';
55
import { ConvParams } from '../commons/types';
66
import { NetInput } from '../NetInput';
7+
import { padToSquare } from '../padToSquare';
78
import { Point } from '../Point';
89
import { toNetInput } from '../toNetInput';
910
import { Dimensions, TNetInput } from '../types';
11+
import { isEven } from '../utils';
1012
import { extractParams } from './extractParams';
1113
import { FaceLandmarks } from './FaceLandmarks';
1214
import { fullyConnectedLayer } from './fullyConnectedLayer';
@@ -41,31 +43,25 @@ export class FaceLandmarkNet {
4143
this._params = extractParams(weights)
4244
}
4345

44-
public async detectLandmarks(input: tf.Tensor | NetInput | TNetInput) {
45-
if (!this._params) {
46+
public forwardTensor(imgTensor: tf.Tensor4D): tf.Tensor2D {
47+
const params = this._params
48+
49+
if (!params) {
4650
throw new Error('FaceLandmarkNet - load model before inference')
4751
}
4852

49-
const netInput = input instanceof tf.Tensor
50-
? input
51-
: await toNetInput(input)
52-
53-
let imageDimensions: Dimensions | undefined
54-
55-
const outTensor = tf.tidy(() => {
56-
const params = this._params
57-
58-
let imgTensor = getImageTensor(netInput)
59-
const [height, width] = imgTensor.shape.slice(1)
60-
imageDimensions = { width, height }
53+
return tf.tidy(() => {
54+
const [batchSize, height, width] = imgTensor.shape.slice()
6155

56+
let x = padToSquare(imgTensor, true)
57+
const [heightAfterPadding, widthAfterPadding] = x.shape.slice(1)
6258

6359
// work with 128 x 128 sized face images
64-
if (imgTensor.shape[1] !== 128 || imgTensor.shape[2] !== 128) {
65-
imgTensor = tf.image.resizeBilinear(imgTensor, [128, 128])
60+
if (heightAfterPadding !== 128 || widthAfterPadding !== 128) {
61+
x = tf.image.resizeBilinear(x, [128, 128])
6662
}
6763

68-
let out = conv(imgTensor, params.conv0_params)
64+
let out = conv(x, params.conv0_params)
6965
out = maxPool(out)
7066
out = conv(out, params.conv1_params)
7167
out = conv(out, params.conv2_params)
@@ -80,14 +76,58 @@ export class FaceLandmarkNet {
8076
const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params))
8177
const fc1 = fullyConnectedLayer(fc0, params.fc1_params)
8278

83-
return fc1
79+
80+
// Builds a [batchSize, 136] tensor in which every row repeats the pair
// (fillX, fillY) 68 times, matching the interleaved x/y layout of the
// landmark output it is combined with.
// The fill tensors carry the batch dimension themselves: stacking two [68]
// vectors yields only 136 values, which cannot be reshaped to
// [batchSize, 136] for batches larger than one.
const createInterleavedTensor = (fillX: number, fillY: number) =>
  tf.stack([
    tf.fill([batchSize, 68], fillX),
    tf.fill([batchSize, 68], fillY)
  ], 2).as2D(batchSize, 136)
85+
86+
87+
/* shift coordinates back, to undo centered padding
88+
((x * widthAfterPadding) - shiftX) / width
89+
((y * heightAfterPadding) - shiftY) / height
90+
*/
91+
const shiftX = Math.floor(Math.abs(widthAfterPadding - width) / 2)
92+
const shiftY = Math.floor(Math.abs(heightAfterPadding - height) / 2)
93+
const landmarkTensor = fc1
94+
.mul(createInterleavedTensor(widthAfterPadding, heightAfterPadding))
95+
.sub(createInterleavedTensor(shiftX, shiftY))
96+
.div(createInterleavedTensor(width, height))
97+
98+
return landmarkTensor as tf.Tensor2D
99+
})
100+
}
101+
102+
/**
 * Runs the landmark net on any supported input and returns the raw output tensor.
 *
 * An input that is not already a tf.Tensor is first converted with
 * toNetInput; the resulting image tensor is then passed to forwardTensor.
 *
 * @param input A tensor, a NetInput, or any value toNetInput accepts.
 * @returns The tensor produced by forwardTensor. It is returned out of the
 *          tidy scope, so it stays alive until the caller disposes it.
 */
public async forward(input: tf.Tensor | NetInput | TNetInput): Promise<tf.Tensor2D> {
  const netInput = input instanceof tf.Tensor
    ? input
    : await toNetInput(input)

  // Wrapped in tf.tidy, as detectLandmarks does, so that any tensor
  // allocated while building the image tensor is released once the result exists.
  // NOTE(review): presumes getImageTensor may allocate new tensors — confirm.
  return tf.tidy(() => this.forwardTensor(getImageTensor(netInput)))
}
109+
110+
public async detectLandmarks(input: tf.Tensor | NetInput | TNetInput) {
111+
const netInput = input instanceof tf.Tensor
112+
? input
113+
: await toNetInput(input)
114+
115+
let imageDimensions: Dimensions | undefined
116+
117+
const outTensor = tf.tidy(() => {
118+
const imgTensor = getImageTensor(netInput)
119+
120+
const [height, width] = imgTensor.shape.slice(1)
121+
imageDimensions = { width, height }
122+
123+
return this.forwardTensor(imgTensor)
84124
})
85125

86126
const faceLandmarksArray = Array.from(await outTensor.data())
87127
outTensor.dispose()
88128

89-
const xCoords = faceLandmarksArray.filter((c, i) => (i - 1) % 2)
90-
const yCoords = faceLandmarksArray.filter((c, i) => i % 2)
129+
const xCoords = faceLandmarksArray.filter((_, i) => isEven(i))
130+
const yCoords = faceLandmarksArray.filter((_, i) => !isEven(i))
91131

92132
return new FaceLandmarks(
93133
Array(68).fill(0).map((_, i) => new Point(xCoords[i], yCoords[i])),

test/tests/e2e/faceLandmarkNet.test.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ describe('faceLandmarkNet', () => {
3131
expect(result.getImageHeight()).toEqual(height)
3232
expect(result.getShift().x).toEqual(0)
3333
expect(result.getShift().y).toEqual(0)
34-
expect(result.getPositions().map(({ x, y }) => ({ x, y }))).toEqual(faceLandmarkPositions)
34+
result.getPositions().forEach(({ x, y }, i) => {
35+
expectMaxDelta(x, faceLandmarkPositions[i].x, 0.1)
36+
expectMaxDelta(y, faceLandmarkPositions[i].y, 0.1)
37+
})
3538
})
3639

3740
})

0 commit comments

Comments
 (0)