Commit cdd2c49

extract methods for face tensors and face canvas + pad and resize face recognition net input to 150x150 + some fixes
1 parent: f7c9038

File tree

9 files changed: +150 additions, −53 deletions

src/NetInput.ts

Lines changed: 2 additions & 5 deletions
@@ -1,5 +1,5 @@
 import { Dimensions, TMediaElement, TNetInput } from './types';
-import { getContext2dOrThrow, getElement, getMediaDimensions } from './utils';
+import { createCanvas, getContext2dOrThrow, getElement, getMediaDimensions } from './utils';

 export class NetInput {
   private _canvases: HTMLCanvasElement[]
@@ -42,10 +42,7 @@ export class NetInput {
     // if input is batch type, make sure every canvas has the same dimensions
     const { width, height } = this.dims || dims || getMediaDimensions(media)

-    const canvas = document.createElement('canvas')
-    canvas.width = width
-    canvas.height = height
-
+    const canvas = createCanvas({ width, height })
     getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
     this._canvases.push(canvas)
   }

src/extractFaceTensors.ts

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import * as tf from '@tensorflow/tfjs-core';
+
+import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
+import { NetInput } from './NetInput';
+import { getImageTensor } from './transformInputs';
+import { TNetInput } from './types';
+
+/**
+ * Extracts the tensors of the image regions containing the detected faces.
+ * Returned tensors have to be disposed manually once you don't need them anymore!
+ * Useful if you want to compute the face descriptors for the face
+ * images. Using this method is faster than extracting a canvas for each face and
+ * converting them to tensors individually.
+ *
+ * @param image The image that face detection has been performed on.
+ * @param detections The face detection results for that image.
+ * @returns Tensors of the corresponding image region for each detected face.
+ */
+export function extractFaceTensors(
+  image: tf.Tensor | NetInput | TNetInput,
+  detections: FaceDetectionResult[]
+): tf.Tensor4D[] {
+  return tf.tidy(() => {
+    const imgTensor = getImageTensor(image)
+
+    // TODO handle batches
+    const [batchSize, imgHeight, imgWidth, numChannels] = imgTensor.shape
+
+    const faceTensors = detections.map(det => {
+      const { x, y, width, height } = det.forSize(imgWidth, imgHeight).box
+      return tf.slice(imgTensor, [0, y, x, 0], [1, height, width, numChannels])
+    })
+
+    return faceTensors
+  })
+}
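A minimal usage sketch for the new helper (not part of this commit; `img` and `detections` are assumed to come from the calling code and the face detection net):

import { extractFaceTensors } from './extractFaceTensors'

// img: canvas, image element or tensor the detections were computed on
// detections: FaceDetectionResult[] returned by the face detection net
const faceTensors = extractFaceTensors(img, detections)

// ... feed each face tensor into the face recognition net here ...

// the returned tensors are not tidied automatically, dispose them once done
faceTensors.forEach(t => t.dispose())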

src/extractFaces.ts

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult';
+import { createCanvas, getContext2dOrThrow } from './utils';
+
+/**
+ * Extracts the image regions containing the detected faces.
+ *
+ * @param image The image that face detection has been performed on.
+ * @param detections The face detection results for that image.
+ * @returns The canvases of the corresponding image region for each detected face.
+ */
+export function extractFaces(
+  image: HTMLCanvasElement,
+  detections: FaceDetectionResult[]
+): HTMLCanvasElement[] {
+  const ctx = getContext2dOrThrow(image)
+
+  return detections.map(det => {
+    const { x, y, width, height } = det.forSize(image.width, image.height).box
+
+    const faceImg = createCanvas({ width, height })
+    getContext2dOrThrow(faceImg)
+      .putImageData(ctx.getImageData(x, y, width, height), 0, 0)
+    return faceImg
+  })
+}
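A short usage sketch (illustrative only; `inputCanvas` and `detections` are assumed to exist in the calling code):

import { extractFaces } from './extractFaces'

// crop one canvas per detected face and append it to the page
const faceCanvases = extractFaces(inputCanvas, detections)
faceCanvases.forEach(faceCanvas => document.body.appendChild(faceCanvas))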

src/faceDetectionNet/FaceDetectionResult.ts

Lines changed: 21 additions & 19 deletions
@@ -1,34 +1,36 @@
 import { FaceDetectionNet } from './types';

 export class FaceDetectionResult {
-  private score: number
-  private top: number
-  private left: number
-  private bottom: number
-  private right: number
+  private _score: number
+  private _topRelative: number
+  private _leftRelative: number
+  private _bottomRelative: number
+  private _rightRelative: number

   constructor(
     score: number,
-    top: number,
-    left: number,
-    bottom: number,
-    right: number
+    topRelative: number,
+    leftRelative: number,
+    bottomRelative: number,
+    rightRelative: number
   ) {
-    this.score = score
-    this.top = Math.max(0, top),
-    this.left = Math.max(0, left),
-    this.bottom = Math.min(1.0, bottom),
-    this.right = Math.min(1.0, right)
+    this._score = score
+    this._topRelative = Math.max(0, topRelative),
+    this._leftRelative = Math.max(0, leftRelative),
+    this._bottomRelative = Math.min(1.0, bottomRelative),
+    this._rightRelative = Math.min(1.0, rightRelative)
   }

   public forSize(width: number, height: number): FaceDetectionNet.Detection {
+    const x = Math.floor(this._leftRelative * width)
+    const y = Math.floor(this._topRelative * height)
     return {
-      score: this.score,
+      score: this._score,
       box: {
-        top: this.top * height,
-        left: this.left * width,
-        bottom: this.bottom * height,
-        right: this.right * width
+        x,
+        y,
+        width: Math.floor(this._rightRelative * width) - x,
+        height: Math.floor(this._bottomRelative * height) - y
       }
     }
   }
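To illustrate the new box format, a worked example with made-up values: the relative coordinates stored in the result are mapped to an absolute x/y/width/height box for a concrete image size.

import { FaceDetectionResult } from './faceDetectionNet/FaceDetectionResult'

// score 0.98, top 0.2, left 0.25, bottom 0.6, right 0.75 (all relative)
const det = new FaceDetectionResult(0.98, 0.2, 0.25, 0.6, 0.75)
const { box } = det.forSize(640, 480)
// box.x      = floor(0.25 * 640)       = 160
// box.y      = floor(0.2  * 480)       = 96
// box.width  = floor(0.75 * 640) - 160 = 320
// box.height = floor(0.6  * 480) - 96  = 192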

src/faceDetectionNet/types.ts

Lines changed: 4 additions & 4 deletions
@@ -69,10 +69,10 @@ export namespace FaceDetectionNet {
   export type Detection = {
     score: number
     box: {
-      top: number,
-      left: number,
-      right: number,
-      bottom: number
+      x: number,
+      y: number,
+      width: number,
+      height: number
     }
   }

src/faceRecognitionNet/index.ts

Lines changed: 7 additions & 1 deletion
@@ -14,7 +14,13 @@ export function faceRecognitionNet(weights: Float32Array) {
   function forward(input: tf.Tensor | NetInput | TNetInput) {
     return tf.tidy(() => {

-      const x = normalize(padToSquare(getImageTensor(input)))
+      // TODO pad on both sides, to keep face centered
+      let x = padToSquare(getImageTensor(input))
+      // work with 150 x 150 sized face images
+      if (x.shape[1] !== 150 || x.shape[2] !== 150) {
+        x = tf.image.resizeBilinear(x, [150, 150])
+      }
+      x = normalize(x)

       let out = convDown(x, params.conv32_down)
       out = tf.maxPool(out, 3, 2, 'valid')
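A self-contained sketch of the preprocessing this change introduces (not the repo's actual padToSquare implementation, which is not shown in this diff): pad a [1, H, W, 3] image tensor to a square, then resize it to the 150 x 150 input size the recognition net now works with.

import * as tf from '@tensorflow/tfjs-core'

function toRecognitionNetInput(imgTensor: tf.Tensor4D): tf.Tensor4D {
  return tf.tidy(() => {
    const [, height, width] = imgTensor.shape
    const diff = Math.abs(height - width)
    // pad the shorter side at the bottom/right (the TODO in the commit notes that
    // padding both sides would keep the face centered)
    const padded = height < width
      ? tf.pad(imgTensor, [[0, 0], [0, diff], [0, 0], [0, 0]]) as tf.Tensor4D
      : tf.pad(imgTensor, [[0, 0], [0, 0], [0, diff], [0, 0]]) as tf.Tensor4D
    // resize to the fixed 150 x 150 net input size, if necessary
    return padded.shape[1] !== 150 || padded.shape[2] !== 150
      ? tf.image.resizeBilinear(padded, [150, 150])
      : padded
  })
}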

src/index.ts

Lines changed: 5 additions & 1 deletion
@@ -2,12 +2,16 @@ import { euclideanDistance } from './euclideanDistance';
 import { faceDetectionNet } from './faceDetectionNet';
 import { faceRecognitionNet } from './faceRecognitionNet';
 import { NetInput } from './NetInput';
+import * as tf from '@tensorflow/tfjs-core';

 export {
   euclideanDistance,
   faceDetectionNet,
   faceRecognitionNet,
-  NetInput
+  NetInput,
+  tf
 }

+export * from './extractFaces'
+export * from './extractFaceTensors'
 export * from './utils'

src/types.ts

Lines changed: 4 additions & 1 deletion
@@ -15,7 +15,10 @@ export type DrawBoxOptions = {
 }

 export type DrawTextOptions = {
+  lineWidth: number
   fontSize: number
   fontStyle: string
   color: string
-}
+}
+
+export type DrawOptions = DrawBoxOptions & DrawTextOptions

src/utils.ts

Lines changed: 46 additions & 22 deletions
@@ -1,5 +1,5 @@
 import { FaceDetectionNet } from './faceDetectionNet/types';
-import { DrawBoxOptions, DrawTextOptions } from './types';
+import { Dimensions, DrawBoxOptions, DrawOptions, DrawTextOptions } from './types';

 export function isFloat(num: number) {
   return num % 1 !== 0
@@ -24,7 +24,23 @@ export function getContext2dOrThrow(canvas: HTMLCanvasElement): CanvasRenderingC
   return ctx
 }

+export function createCanvas({ width, height}: Dimensions): HTMLCanvasElement {
+  const canvas = document.createElement('canvas')
+  canvas.width = width
+  canvas.height = height
+  return canvas
+}
+
+export function createCanvasWithImageData({ width, height}: Dimensions, buf: Uint8ClampedArray): HTMLCanvasElement {
+  const canvas = createCanvas({ width, height })
+  getContext2dOrThrow(canvas).putImageData(new ImageData(buf, width, height), 0, 0)
+  return canvas
+}
+
 export function getMediaDimensions(media: HTMLImageElement | HTMLVideoElement) {
+  if (media instanceof HTMLImageElement) {
+    return { width: media.naturalWidth, height: media.naturalHeight }
+  }
   if (media instanceof HTMLVideoElement) {
     return { width: media.videoWidth, height: media.videoHeight }
   }
@@ -49,6 +65,15 @@ export function bufferToImage(buf: Blob): Promise<HTMLImageElement> {
   })
 }

+export function getDefaultDrawOptions(): DrawOptions {
+  return {
+    color: 'blue',
+    lineWidth: 2,
+    fontSize: 20,
+    fontStyle: 'Georgia'
+  }
+}
+
 export function drawBox(
   ctx: CanvasRenderingContext2D,
   x: number,
@@ -69,9 +94,11 @@ export function drawText(
   text: string,
   options: DrawTextOptions
 ) {
+  const padText = 2 + options.lineWidth
+
   ctx.fillStyle = options.color
   ctx.font = `${options.fontSize}px ${options.fontStyle}`
-  ctx.fillText(text, x, y)
+  ctx.fillText(text, x + padText, y + padText + (options.fontSize * 0.6))
 }

 export function drawDetection(
@@ -95,38 +122,35 @@ export function drawDetection(
     } = det

     const {
-      left,
-      right,
-      top,
-      bottom
+      x,
+      y,
+      width,
+      height
     } = box

-    const {
-      color = 'blue',
-      lineWidth = 2,
-      fontSize = 20,
-      fontStyle = 'Georgia',
-      withScore = true
-    } = (options || {})
+    const drawOptions = Object.assign(
+      getDefaultDrawOptions(),
+      (options || {})
+    )

-    const padText = 2 + lineWidth
+    const { withScore } = Object.assign({ withScore: true }, (options || {}))

     const ctx = getContext2dOrThrow(canvas)
     drawBox(
       ctx,
-      left,
-      top,
-      right - left,
-      bottom - top,
-      { lineWidth, color }
+      x,
+      y,
+      width,
+      height,
+      drawOptions
     )
     if (withScore) {
       drawText(
         ctx,
-        left + padText,
-        top + (fontSize * 0.6) + padText,
+        x,
+        y,
         `${round(score)}`,
-        { fontSize, fontStyle, color }
+        drawOptions
       )
     }
   })