Skip to content

Commit 1d133c5

Browse files
fixed MTCNN forwardParams initialization for optional params
1 parent 5d87492 commit 1d133c5

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

src/globalApi.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ export function computeFaceDescriptor(
7373

7474
export function mtcnn(
7575
input: TNetInput,
76-
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
76+
forwardParams: MtcnnForwardParams
7777
): Promise<MtcnnResult[]> {
78-
return nets.mtcnn.forward(input, forwardParameters)
78+
return nets.mtcnn.forward(input, forwardParams)
7979
}
8080

8181
export const allFaces: (

src/mtcnn/Mtcnn.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
2828

2929
public async forwardInput(
3030
input: NetInput,
31-
{ minFaceSize, scaleFactor, maxNumScales, scoreThresholds, scaleSteps } = getDefaultMtcnnForwardParams()
31+
forwardParams: MtcnnForwardParams
3232
): Promise<{ results: MtcnnResult[], stats: any }> {
3333

3434
const { params } = this
@@ -64,6 +64,14 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
6464

6565
const [height, width] = imgTensor.shape.slice(1)
6666

67+
const {
68+
minFaceSize,
69+
scaleFactor,
70+
maxNumScales,
71+
scoreThresholds,
72+
scaleSteps
73+
} = Object.assign({}, getDefaultMtcnnForwardParams(), forwardParams)
74+
6775
const scales = scaleSteps || pyramidDown(minFaceSize, scaleFactor, [height, width])
6876
.filter(scale => {
6977
const sizes = getSizesForScale(scale, [height, width])
@@ -124,23 +132,23 @@ export class Mtcnn extends NeuralNetwork<NetParams> {
124132

125133
public async forward(
126134
input: TNetInput,
127-
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
135+
forwardParams: MtcnnForwardParams
128136
): Promise<MtcnnResult[]> {
129137
return (
130138
await this.forwardInput(
131139
await toNetInput(input, true, true),
132-
forwardParameters
140+
forwardParams
133141
)
134142
).results
135143
}
136144

137145
public async forwardWithStats(
138146
input: TNetInput,
139-
forwardParameters: MtcnnForwardParams = getDefaultMtcnnForwardParams()
147+
forwardParams: MtcnnForwardParams
140148
): Promise<{ results: MtcnnResult[], stats: any }> {
141149
return this.forwardInput(
142150
await toNetInput(input, true, true),
143-
forwardParameters
151+
forwardParams
144152
)
145153
}
146154

src/mtcnn/getDefaultMtcnnForwardParams.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import { MtcnnForwardParams } from './types';
2-
3-
export function getDefaultMtcnnForwardParams(): MtcnnForwardParams {
1+
export function getDefaultMtcnnForwardParams() {
42
return {
53
minFaceSize: 20,
64
scaleFactor: 0.709,

src/mtcnn/types.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ export type MtcnnResult = {
4747
}
4848

4949
export type MtcnnForwardParams = {
50-
minFaceSize: number
51-
scaleFactor: number
52-
maxNumScales: number
53-
scoreThresholds: number[]
50+
minFaceSize?: number
51+
scaleFactor?: number
52+
maxNumScales?: number
53+
scoreThresholds?: number[]
5454
scaleSteps?: number[]
5555
}

0 commit comments

Comments
 (0)