@@ -8,7 +8,7 @@ import { FaceDetection } from '../FaceDetection';
8
8
import { NetInput } from '../NetInput' ;
9
9
import { toNetInput } from '../toNetInput' ;
10
10
import { TNetInput } from '../types' ;
11
- import { BOX_ANCHORS , INPUT_SIZES , IOU_THRESHOLD , NUM_BOXES , NUM_CELLS } from './config' ;
11
+ import { BOX_ANCHORS , INPUT_SIZES , IOU_THRESHOLD , NUM_BOXES } from './config' ;
12
12
import { convWithBatchNorm } from './convWithBatchNorm' ;
13
13
import { extractParams } from './extractParams' ;
14
14
import { getDefaultParams } from './getDefaultParams' ;
@@ -59,18 +59,19 @@ export class TinyYolov2 extends NeuralNetwork<NetParams> {
59
59
60
60
public async locateFaces ( input : TNetInput , forwardParams : TinyYolov2ForwardParams = { } ) : Promise < FaceDetection [ ] > {
61
61
62
- const { sizeType , scoreThreshold } = getDefaultParams ( forwardParams )
62
+ const { inputSize : _inputSize , scoreThreshold } = getDefaultParams ( forwardParams )
63
63
64
+ const inputSize = typeof _inputSize === 'string'
65
+ ? INPUT_SIZES [ _inputSize ]
66
+ : _inputSize
64
67
65
- const inputSize = INPUT_SIZES [ sizeType ]
66
- const numCells = NUM_CELLS [ sizeType ]
67
-
68
- if ( ! inputSize ) {
69
- throw new Error ( `TinyYolov2 - unkown sizeType: ${ sizeType } , expected one of: xs | sm | md | lg` )
68
+ if ( typeof inputSize !== 'number' ) {
69
+ throw new Error ( `TinyYolov2 - unkown inputSize: ${ inputSize } , expected number or one of xs | sm | md | lg` )
70
70
}
71
71
72
72
const netInput = await toNetInput ( input , true )
73
73
const out = await this . forwardInput ( netInput , inputSize )
74
+ const numCells = out . shape [ 1 ]
74
75
75
76
const [ boxesTensor , scoresTensor ] = tf . tidy ( ( ) => {
76
77
const reshaped = out . reshape ( [ numCells , numCells , NUM_BOXES , 6 ] )
0 commit comments