Skip to content

Commit ec095c0

Browse files
final landmark net implementation
1 parent 72d280c commit ec095c0

File tree

7 files changed

+166
-47
lines changed

7 files changed

+166
-47
lines changed

src/NetInput.ts

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { Dimensions, TMediaElement, TNetInput } from './types';
2-
import { createCanvas, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
2+
import { createCanvasFromMedia, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
33

44
export class NetInput {
55
private _canvases: HTMLCanvasElement[]
@@ -40,11 +40,8 @@ export class NetInput {
4040
}
4141

4242
// if input is batch type, make sure every canvas has the same dimensions
43-
const { width, height } = this.dims || dims || getMediaDimensions(media)
44-
45-
const canvas = createCanvas({ width, height })
46-
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
47-
this._canvases.push(canvas)
43+
const canvasDims = this.dims || dims
44+
this._canvases.push(createCanvasFromMedia(media, canvasDims))
4845
}
4946

5047
public get canvases() : HTMLCanvasElement[] {

src/Point.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export class Point {
2+
public x: number
3+
public y: number
4+
5+
constructor(x: number, y: number) {
6+
this.x = x
7+
this.y = y
8+
}
9+
}

src/commons/convLayer.ts

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ import { ConvParams } from './types';
55
export function convLayer(
66
x: tf.Tensor4D,
77
params: ConvParams,
8-
padding: 'valid' | 'same' = 'same'
8+
padding: 'valid' | 'same' = 'same',
9+
withRelu: boolean = false
910
): tf.Tensor4D {
10-
return tf.tidy(() =>
11-
tf.add(
11+
return tf.tidy(() => {
12+
const out = tf.add(
1213
tf.conv2d(x, params.filters, [1, 1], padding),
1314
params.bias
14-
)
15-
)
15+
) as tf.Tensor4D
16+
17+
return withRelu ? tf.relu(out) : out
18+
})
1619
}

src/faceLandmarkNet/FaceLandmarks.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { Point } from '../Point';
2+
import { Dimensions } from '../types';
3+
4+
export class FaceLandmarks {
5+
private _faceLandmarks: Point[]
6+
private _imageWidth: number
7+
private _imageHeight: number
8+
9+
constructor(
10+
relativeFaceLandmarkPositions: Point[],
11+
imageDims: Dimensions
12+
) {
13+
const { width, height } = imageDims
14+
this._imageWidth = width
15+
this._imageHeight = height
16+
this._faceLandmarks = relativeFaceLandmarkPositions.map(
17+
pt => new Point(pt.x * width, pt.y * height)
18+
)
19+
}
20+
21+
public getPositions() {
22+
return this._faceLandmarks
23+
}
24+
25+
public getRelativePositions() {
26+
return this._faceLandmarks.map(
27+
pt => new Point(pt.x / this._imageWidth, pt.y / this._imageHeight)
28+
)
29+
}
30+
31+
public forSize(width: number, height: number): FaceLandmarks {
32+
return new FaceLandmarks(this.getRelativePositions(), { width, height })
33+
}
34+
}

src/faceLandmarkNet/index.ts

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,75 @@
11
import * as tf from '@tensorflow/tfjs-core';
22

3+
import { convLayer } from '../commons/convLayer';
4+
import { ConvParams } from '../commons/types';
35
import { getImageTensor } from '../getImageTensor';
46
import { NetInput } from '../NetInput';
57
import { padToSquare } from '../padToSquare';
6-
import { TNetInput } from '../types';
8+
import { Dimensions, TNetInput } from '../types';
79
import { extractParams } from './extractParams';
8-
import { convLayer } from '../commons/convLayer';
10+
import { FaceLandmarks } from './FaceLandmarks';
911
import { fullyConnectedLayer } from './fullyConnectedLayer';
1012

13+
function conv(x: tf.Tensor4D, params: ConvParams): tf.Tensor4D {
14+
return convLayer(x, params, 'valid', true)
15+
}
16+
17+
function maxPool(x: tf.Tensor4D, strides: [number, number] = [2, 2]): tf.Tensor4D {
18+
return tf.maxPool(x, [2, 2], strides, 'valid')
19+
}
20+
1121
export function faceLandmarkNet(weights: Float32Array) {
1222
const params = extractParams(weights)
1323

14-
function forward(input: tf.Tensor | NetInput | TNetInput) {
15-
return tf.tidy(() => {
24+
async function detectLandmarks(input: tf.Tensor | NetInput | TNetInput) {
25+
let adjustRelativeX = 0
26+
let adjustRelativeY = 0
27+
let imageDimensions: Dimensions | undefined
28+
29+
const outTensor = tf.tidy(() => {
30+
let imgTensor = getImageTensor(input)
31+
const [height, width] = imgTensor.shape.slice(1)
32+
imageDimensions = { width, height }
33+
34+
imgTensor = padToSquare(imgTensor, true)
35+
adjustRelativeX = (height > width) ? imgTensor.shape[2] / (2 * width) : 0
36+
adjustRelativeY = (width > height) ? imgTensor.shape[1] / (2 * height) : 0
1637

17-
let x = padToSquare(getImageTensor(input), true)
1838
// work with 128 x 128 sized face images
19-
if (x.shape[1] !== 128 || x.shape[2] !== 128) {
20-
x = tf.image.resizeBilinear(x, [128, 128])
39+
if (imgTensor.shape[1] !== 128 || imgTensor.shape[2] !== 128) {
40+
imgTensor = tf.image.resizeBilinear(imgTensor, [128, 128])
2141
}
2242

23-
let out = convLayer(x, params.conv0_params, 'valid')
24-
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
25-
out = convLayer(out, params.conv1_params, 'valid')
26-
out = convLayer(out, params.conv2_params, 'valid')
27-
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
28-
out = convLayer(out, params.conv3_params, 'valid')
29-
out = convLayer(out, params.conv4_params, 'valid')
30-
out = tf.maxPool(out, [2, 2], [2, 2], 'valid')
31-
out = convLayer(out, params.conv5_params, 'valid')
32-
out = convLayer(out, params.conv6_params, 'valid')
33-
out = tf.maxPool(out, [2, 2], [1, 1], 'valid')
34-
out = convLayer(out, params.conv7_params, 'valid')
35-
const fc0 = fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params)
43+
let out = conv(imgTensor, params.conv0_params)
44+
out = maxPool(out)
45+
out = conv(out, params.conv1_params)
46+
out = conv(out, params.conv2_params)
47+
out = maxPool(out)
48+
out = conv(out, params.conv3_params)
49+
out = conv(out, params.conv4_params)
50+
out = maxPool(out)
51+
out = conv(out, params.conv5_params)
52+
out = conv(out, params.conv6_params)
53+
out = maxPool(out, [1, 1])
54+
out = conv(out, params.conv7_params)
55+
const fc0 = tf.relu(fullyConnectedLayer(out.as2D(out.shape[0], -1), params.fc0_params))
3656
const fc1 = fullyConnectedLayer(fc0, params.fc1_params)
3757

3858
return fc1
3959
})
60+
61+
const faceLandmarksArray = Array.from(await outTensor.data())
62+
const xCoords = faceLandmarksArray.filter((c, i) => (i - 1) % 2).map(x => x + adjustRelativeX)
63+
const yCoords = faceLandmarksArray.filter((c, i) => i % 2).map(y => y + adjustRelativeY)
64+
outTensor.dispose()
65+
66+
return new FaceLandmarks(
67+
Array(68).fill(0).map((_, i) => ({ x: xCoords[i], y: yCoords[i] })),
68+
imageDimensions as Dimensions
69+
)
4070
}
4171

4272
return {
43-
forward
73+
detectLandmarks
4474
}
4575
}

src/types.ts

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,25 @@ export type Dimensions = {
1010
}
1111

1212
export type DrawBoxOptions = {
13-
lineWidth: number
14-
color: string
13+
lineWidth?: number
14+
color?: string
1515
}
1616

1717
export type DrawTextOptions = {
18+
lineWidth?: number
19+
fontSize?: number
20+
fontStyle?: string
21+
color?: string
22+
}
23+
24+
export type DrawLandmarksOptions = {
25+
lineWidth?: number
26+
color?: string
27+
}
28+
29+
export type DrawOptions = {
1830
lineWidth: number
1931
fontSize: number
2032
fontStyle: string
2133
color: string
22-
}
23-
24-
export type DrawOptions = DrawBoxOptions & DrawTextOptions
34+
}

src/utils.ts

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { FaceDetectionNet } from './faceDetectionNet/types';
2-
import { Dimensions, DrawBoxOptions, DrawOptions, DrawTextOptions } from './types';
2+
import { FaceLandmarks } from './faceLandmarkNet/FaceLandmarks';
3+
import { Dimensions, DrawBoxOptions, DrawLandmarksOptions, DrawOptions, DrawTextOptions } from './types';
34

45
export function isFloat(num: number) {
56
return num % 1 !== 0
@@ -24,16 +25,17 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC
2425
return ctx
2526
}
2627

27-
export function createCanvas({ width, height}: Dimensions): HTMLCanvasElement {
28+
export function createCanvas({ width, height }: Dimensions): HTMLCanvasElement {
2829
const canvas = document.createElement('canvas')
2930
canvas.width = width
3031
canvas.height = height
3132
return canvas
3233
}
3334

34-
export function createCanvasWithImageData({ width, height}: Dimensions, buf: Uint8ClampedArray): HTMLCanvasElement {
35+
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): HTMLCanvasElement {
36+
const { width, height } = dims || getMediaDimensions(media)
3537
const canvas = createCanvas({ width, height })
36-
getContext2dOrThrow(canvas).putImageData(new ImageData(buf, width, height), 0, 0)
38+
getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
3739
return canvas
3840
}
3941

@@ -82,8 +84,13 @@ export function drawBox(
8284
h: number,
8385
options: DrawBoxOptions
8486
) {
85-
ctx.strokeStyle = options.color
86-
ctx.lineWidth = options.lineWidth
87+
const drawOptions = Object.assign(
88+
getDefaultDrawOptions(),
89+
(options || {})
90+
)
91+
92+
ctx.strokeStyle = drawOptions.color
93+
ctx.lineWidth = drawOptions.lineWidth
8794
ctx.strokeRect(x, y, w, h)
8895
}
8996

@@ -94,11 +101,16 @@ export function drawText(
94101
text: string,
95102
options: DrawTextOptions
96103
) {
97-
const padText = 2 + options.lineWidth
104+
const drawOptions = Object.assign(
105+
getDefaultDrawOptions(),
106+
(options || {})
107+
)
108+
109+
const padText = 2 + drawOptions.lineWidth
98110

99-
ctx.fillStyle = options.color
100-
ctx.font = `${options.fontSize}px ${options.fontStyle}`
101-
ctx.fillText(text, x + padText, y + padText + (options.fontSize * 0.6))
111+
ctx.fillStyle = drawOptions.color
112+
ctx.font = `${drawOptions.fontSize}px ${drawOptions.fontStyle}`
113+
ctx.fillText(text, x + padText, y + padText + (drawOptions.fontSize * 0.6))
102114
}
103115

104116
export function drawDetection(
@@ -154,4 +166,28 @@ export function drawDetection(
154166
)
155167
}
156168
})
169+
}
170+
171+
export function drawLandmarks(
172+
canvasArg: string | HTMLCanvasElement,
173+
faceLandmarks: FaceLandmarks,
174+
options?: DrawLandmarksOptions & { drawLines: boolean }
175+
) {
176+
const canvas = getElement(canvasArg)
177+
if (!(canvas instanceof HTMLCanvasElement)) {
178+
throw new Error('drawLandmarks - expected canvas to be of type: HTMLCanvasElement')
179+
}
180+
181+
const drawOptions = Object.assign(
182+
getDefaultDrawOptions(),
183+
(options || {})
184+
)
185+
186+
const { drawLines } = Object.assign({ drawLines: false }, (options || {}))
187+
188+
const ctx = getContext2dOrThrow(canvas)
189+
const { lineWidth,color } = drawOptions
190+
ctx.fillStyle = color
191+
const ptOffset = lineWidth / 2
192+
faceLandmarks.getPositions().forEach(pt => ctx.fillRect(pt.x - ptOffset, pt.y - ptOffset, lineWidth, lineWidth))
157193
}

0 commit comments

Comments
 (0)