1
1
import * as tf from '@tensorflow/tfjs-core' ;
2
2
3
3
import { Point } from '../Point' ;
4
+ import { BoundingBox } from './BoundingBox' ;
4
5
import { CELL_SIZE , CELL_STRIDE } from './config' ;
6
+ import { nms } from './nms' ;
5
7
import { PNet } from './PNet' ;
6
8
import { PNetParams } from './types' ;
7
9
@@ -18,96 +20,116 @@ function rescaleAndNormalize(x: tf.Tensor4D, scale: number): tf.Tensor4D {
18
20
19
21
20
22
/**
 * Collects every cell of the score map whose score reaches `scoreThreshold`
 * and describes it by its cell rectangle, its score and the four values
 * stored for that cell in `regionsTensor`.
 *
 * @param scoresTensor 2D score map, read as (y, x).
 * @param regionsTensor 3D map holding four values per (y, x) cell.
 * @param scale Divisor applied to the cell coordinates before rounding.
 * @param scoreThreshold Minimum score (inclusive) a cell needs to be kept.
 * @returns One `{ cell, score, region }` entry per kept cell, in row-major order.
 */
function extractBoundingBoxes(
  scoresTensor: tf.Tensor2D,
  regionsTensor: tf.Tensor3D,
  scale: number,
  scoreThreshold: number
) {

  const [numRows, numCols] = scoresTensor.shape

  // Converts a cell index plus an offset inside the cell into a rounded, rescaled coordinate.
  const toScaledCoord = (cellIdx: number, offset: number) =>
    Math.round((cellIdx * CELL_STRIDE + offset) / scale)

  // TODO: values are read one by one here, tf.gather may be a better fit
  const candidates: Point[] = []
  for (let row = 0; row < numRows; row++) {
    for (let col = 0; col < numCols; col++) {
      if (scoresTensor.get(row, col) >= scoreThreshold) {
        candidates.push(new Point(col, row))
      }
    }
  }

  return candidates.map(({ x, y }) => ({
    cell: new BoundingBox(
      toScaledCoord(x, 1),
      toScaledCoord(y, 1),
      toScaledCoord(x, CELL_SIZE),
      toScaledCoord(y, CELL_SIZE)
    ),
    score: scoresTensor.get(y, x),
    region: new BoundingBox(
      regionsTensor.get(y, x, 0),
      regionsTensor.get(y, x, 1),
      regionsTensor.get(y, x, 2),
      regionsTensor.get(y, x, 3)
    )
  }))
}
80
65
81
- const boxes = scales . map ( ( scale , i ) => {
82
- let resized = i === 0
83
- // TODO: debug
84
- ? tf . tensor4d ( window . resizedData , [ 1 , 820 , 461 , 3 ] )
66
/**
 * Runs PNet on the image at every given scale, keeps the cells whose score
 * reaches `scoreThreshold`, applies non-maximum suppression per scale (0.5)
 * and across all scales (0.7), and finally shifts each surviving cell by its
 * region values (scaled by the cell's width / height).
 *
 * @param imgTensor Input image batch handed to `rescaleAndNormalize`.
 * @param scales Scales to evaluate; an empty list yields an empty result.
 * @param scoreThreshold Minimum score (inclusive) for a cell to be considered.
 * @param params PNet weights.
 * @returns The squared and rounded boxes together with their scores, or `[]`
 *          when no cell passes the threshold.
 */
export function stage1(imgTensor: tf.Tensor4D, scales: number[], scoreThreshold: number, params: PNetParams) {

  const boxesForScale = scales.map((scale) => {

    // Intermediate tensors are released by tidy, only the two returned maps survive.
    const { scoresTensor, regionsTensor } = tf.tidy(() => {
      const resized = rescaleAndNormalize(imgTensor, scale)
      const { prob, regions } = PNet(resized, params)

      const scores = tf.unstack(prob, 3)[1]
      const [sh, sw] = scores.shape.slice(1)
      const [rh, rw] = regions.shape.slice(1)

      return {
        scoresTensor: scores.as2D(sh, sw),
        regionsTensor: regions.as3D(rh, rw, 4)
      }
    })

    let boundingBoxes: ReturnType<typeof extractBoundingBoxes>
    try {
      boundingBoxes = extractBoundingBoxes(
        scoresTensor,
        regionsTensor,
        scale,
        scoreThreshold
      )
    } finally {
      // Release the maps even if the extraction throws, otherwise they leak.
      scoresTensor.dispose()
      regionsTensor.dispose()
    }

    if (!boundingBoxes.length) {
      return []
    }

    const indices = nms(
      boundingBoxes.map(bbox => bbox.cell),
      boundingBoxes.map(bbox => bbox.score),
      0.5
    )

    return indices.map(boxIdx => boundingBoxes[boxIdx])
  })

  // The explicit initial value keeps reduce from throwing when `scales` is empty.
  const allBoxes = boxesForScale.reduce<typeof boxesForScale[number]>(
    (all, boxes) => all.concat(boxes),
    []
  )

  if (!allBoxes.length) {
    return []
  }

  const indices = nms(
    allBoxes.map(bbox => bbox.cell),
    allBoxes.map(bbox => bbox.score),
    0.7
  )

  const finalBoxes = indices
    .map(idx => allBoxes[idx])
    .map(({ cell, region, score }) => ({
      // NOTE(review): argument order follows the (x, y, x, y) order used when the
      // cell boxes are constructed, i.e. presumably (left, top, right, bottom);
      // confirm against the BoundingBox constructor.
      box: new BoundingBox(
        cell.left + (region.left * cell.width),
        cell.top + (region.top * cell.height),
        cell.right + (region.right * cell.width),
        cell.bottom + (region.bottom * cell.height)
      ).toSquare().round(),
      score
    }))

  return finalBoxes
}
0 commit comments