Skip to content

Commit e95dd35

Browse files
allow arbitrary input size
1 parent 28a7745 commit e95dd35

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

src/tinyYolov2/TinyYolov2.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import { FaceDetection } from '../FaceDetection';
88
import { NetInput } from '../NetInput';
99
import { toNetInput } from '../toNetInput';
1010
import { TNetInput } from '../types';
11-
import { BOX_ANCHORS, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES, NUM_CELLS } from './config';
11+
import { BOX_ANCHORS, INPUT_SIZES, IOU_THRESHOLD, NUM_BOXES } from './config';
1212
import { convWithBatchNorm } from './convWithBatchNorm';
1313
import { extractParams } from './extractParams';
1414
import { getDefaultParams } from './getDefaultParams';
@@ -59,18 +59,19 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
5959

6060
public async locateFaces(input: TNetInput, forwardParams: TinyYolov2ForwardParams = {}): Promise<FaceDetection[]> {
6161

62-
const { sizeType, scoreThreshold } = getDefaultParams(forwardParams)
62+
const { inputSize: _inputSize, scoreThreshold } = getDefaultParams(forwardParams)
6363

64+
const inputSize = typeof _inputSize === 'string'
65+
? INPUT_SIZES[_inputSize]
66+
: _inputSize
6467

65-
const inputSize = INPUT_SIZES[sizeType]
66-
const numCells = NUM_CELLS[sizeType]
67-
68-
if (!inputSize) {
69-
throw new Error(`TinyYolov2 - unkown sizeType: ${sizeType}, expected one of: xs | sm | md | lg`)
68+
if (typeof inputSize !== 'number') {
69+
throw new Error(`TinyYolov2 - unkown inputSize: ${inputSize}, expected number or one of xs | sm | md | lg`)
7070
}
7171

7272
const netInput = await toNetInput(input, true)
7373
const out = await this.forwardInput(netInput, inputSize)
74+
const numCells = out.shape[1]
7475

7576
const [boxesTensor, scoresTensor] = tf.tidy(() => {
7677
const reshaped = out.reshape([numCells, numCells, NUM_BOXES, 6])

src/tinyYolov2/config.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { Point } from '../Point';
22

33
export const INPUT_SIZES = { xs: 224, sm: 320, md: 416, lg: 608 }
4-
export const NUM_CELLS = { xs: 7, sm: 10, md: 13, lg: 19 }
54
export const NUM_BOXES = 5
65
export const IOU_THRESHOLD = 0.4
76

src/tinyYolov2/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,6 @@ export enum SizeType {
3232
}
3333

3434
export type TinyYolov2ForwardParams = {
35-
sizeType?: SizeType
35+
inputSize?: SizeType | number
3636
scoreThreshold?: number
3737
}

0 commit comments

Comments
 (0)