Skip to content

Commit 3ca0f4a

Browse files
pad input to square
1 parent 89e9691 commit 3ca0f4a

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

src/faceDetectionNet/index.ts

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@ function fromImageData(input: ImageData[]) {
3333
return tf.cast(tf.concat(imgTensors, 0), 'float32')
3434
}
3535

36+
function padToSquare(imgTensor: tf.Tensor4D): tf.Tensor4D {
37+
const [_, height, width] = imgTensor.shape
38+
if (height === width) {
39+
return imgTensor
40+
}
41+
42+
if (height > width) {
43+
const pad = tf.fill([1, height, height - width, 3], 0) as tf.Tensor4D
44+
return tf.concat([imgTensor, pad], 2)
45+
}
46+
const pad = tf.fill([1, width - height, width, 3], 0) as tf.Tensor4D
47+
return tf.concat([imgTensor, pad], 1)
48+
}
49+
3650
function getImgTensor(input: ImageData|ImageData[]|number[]) {
3751
return tf.tidy(() => {
3852

@@ -44,9 +58,11 @@ function getImgTensor(input: ImageData|ImageData[]|number[]) {
4458
: null
4559
)
4660

47-
return imgDataArray !== null
48-
? fromImageData(imgDataArray)
49-
: fromData(input as number[])
61+
return padToSquare(
62+
imgDataArray !== null
63+
? fromImageData(imgDataArray)
64+
: fromData(input as number[])
65+
)
5066

5167
})
5268
}
@@ -71,7 +87,7 @@ export function faceDetectionNet(weights: Float32Array) {
7187

7288
function forward(input: ImageData|ImageData[]|number[]) {
7389
return tf.tidy(
74-
() => forwardTensor(getImgTensor(input))
90+
() => forwardTensor(padToSquare(getImgTensor(input)))
7591
)
7692
}
7793

@@ -81,7 +97,6 @@ export function faceDetectionNet(weights: Float32Array) {
8197
maxResults: number = 100,
8298
): Promise<FaceDetectionNet.Detection[]> {
8399
const imgTensor = getImgTensor(input)
84-
85100
const [_, height, width] = imgTensor.shape
86101

87102
const {

src/index.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import * as tf from '@tensorflow/tfjs-core';
2+
13
import { euclideanDistance } from './euclideanDistance';
24
import { faceDetectionNet } from './faceDetectionNet';
35
import { faceRecognitionNet } from './faceRecognitionNet';
@@ -7,7 +9,8 @@ export {
79
euclideanDistance,
810
faceDetectionNet,
911
faceRecognitionNet,
10-
normalize
12+
normalize,
13+
tf
1114
}
1215

1316
export * from './utils'

src/utils.ts

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@ export function round(num: number) {
2323
return Math.floor(num * 100) / 100
2424
}
2525

26+
export type Dimensions = {
27+
width: number
28+
height: number
29+
}
30+
2631
export function drawMediaToCanvas(
2732
canvasArg: string | HTMLCanvasElement,
28-
mediaArg: string | HTMLImageElement | HTMLVideoElement
33+
mediaArg: string | HTMLImageElement | HTMLVideoElement,
34+
dims?: Dimensions
2935
): CanvasRenderingContext2D {
3036
const canvas = getElement(canvasArg)
3137
const media = getElement(mediaArg)
@@ -37,21 +43,24 @@ export function drawMediaToCanvas(
3743
throw new Error('drawMediaToCanvas - expected media to be of type: HTMLImageElement | HTMLVideoElement')
3844
}
3945

40-
canvas.width = media.width
41-
canvas.height = media.height
46+
const { width, height } = dims || media
47+
canvas.width = width
48+
canvas.height = height
4249

4350
const ctx = getContext2dOrThrow(canvas)
44-
ctx.drawImage(media, 0, 0, media.width, media.height)
51+
ctx.drawImage(media, 0, 0, width, height)
4552
return ctx
4653
}
4754

48-
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement): ImageData {
55+
export function mediaToImageData(media: HTMLImageElement | HTMLVideoElement, dims?: Dimensions): ImageData {
4956
if (!(media instanceof HTMLImageElement || media instanceof HTMLVideoElement)) {
5057
throw new Error('mediaToImageData - expected media to be of type: HTMLImageElement | HTMLVideoElement')
5158
}
5259

5360
const ctx = drawMediaToCanvas(document.createElement('canvas'), media)
54-
return ctx.getImageData(0, 0, media.width, media.height)
61+
62+
const { width, height } = dims || media
63+
return ctx.getImageData(0, 0, width, height)
5564
}
5665

5766
export function mediaSrcToImageData(

0 commit comments

Comments
 (0)