@@ -20,19 +20,42 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
20
20
}
21
21
}
22
22
23
- function extractPointwiseConvParams ( channelsIn : number , channelsOut : number ) : FaceDetectionNet . PointwiseConvParams {
24
- const filters = tf . tensor4d ( extractWeights ( channelsIn * channelsOut ) , [ 1 , 1 , channelsIn , channelsOut ] )
25
- const batch_norm_offset = tf . tensor1d ( extractWeights ( channelsOut ) )
23
+ function extractConvWithBiasParams (
24
+ channelsIn : number ,
25
+ channelsOut : number ,
26
+ filterSize : number
27
+ ) : FaceDetectionNet . ConvWithBiasParams {
28
+ const filters = tf . tensor4d (
29
+ extractWeights ( channelsIn * channelsOut * filterSize * filterSize ) ,
30
+ [ filterSize , filterSize , channelsIn , channelsOut ]
31
+ )
32
+ const bias = tf . tensor1d ( extractWeights ( channelsOut ) )
26
33
27
34
return {
28
35
filters,
29
- batch_norm_offset
36
+ bias
37
+ }
38
+ }
39
+
40
+ function extractPointwiseConvParams (
41
+ channelsIn : number ,
42
+ channelsOut : number ,
43
+ filterSize : number
44
+ ) : FaceDetectionNet . PointwiseConvParams {
45
+ const {
46
+ filters,
47
+ bias
48
+ } = extractConvWithBiasParams ( channelsIn , channelsOut , filterSize )
49
+
50
+ return {
51
+ filters,
52
+ batch_norm_offset : bias
30
53
}
31
54
}
32
55
33
56
function extractConvPairParams ( channelsIn : number , channelsOut : number ) : FaceDetectionNet . MobileNetV1 . ConvPairParams {
34
57
const depthwise_conv_params = extractDepthwiseConvParams ( channelsIn )
35
- const pointwise_conv_params = extractPointwiseConvParams ( channelsIn , channelsOut )
58
+ const pointwise_conv_params = extractPointwiseConvParams ( channelsIn , channelsOut , 1 )
36
59
37
60
return {
38
61
depthwise_conv_params,
@@ -42,11 +65,7 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
42
65
43
66
function extractMobilenetV1Params ( ) : FaceDetectionNet . MobileNetV1 . Params {
44
67
45
- const conv_0_params = {
46
- filters : tf . tensor4d ( extractWeights ( 3 * 3 * 3 * 32 ) , [ 3 , 3 , 3 , 32 ] ) ,
47
- batch_norm_offset : tf . tensor1d ( extractWeights ( 32 ) )
48
-
49
- }
68
+ const conv_0_params = extractPointwiseConvParams ( 3 , 32 , 3 )
50
69
51
70
const channelNumPairs = [
52
71
[ 32 , 64 ] ,
@@ -75,32 +94,101 @@ function extractorsFactory(extractWeights: (numWeights: number) => Float32Array)
75
94
76
95
}
77
96
97
+ function extractPredictionLayerParams ( ) : FaceDetectionNet . PredictionParams {
98
+ const conv_0_params = extractPointwiseConvParams ( 1024 , 256 , 1 )
99
+ const conv_1_params = extractPointwiseConvParams ( 256 , 512 , 3 )
100
+ const conv_2_params = extractPointwiseConvParams ( 512 , 128 , 1 )
101
+ const conv_3_params = extractPointwiseConvParams ( 128 , 256 , 3 )
102
+ const conv_4_params = extractPointwiseConvParams ( 256 , 128 , 1 )
103
+ const conv_5_params = extractPointwiseConvParams ( 128 , 256 , 3 )
104
+ const conv_6_params = extractPointwiseConvParams ( 256 , 64 , 1 )
105
+ const conv_7_params = extractPointwiseConvParams ( 64 , 128 , 3 )
106
+
107
+ const box_encoding_0_predictor_params = extractConvWithBiasParams ( 512 , 12 , 1 )
108
+ const class_predictor_0_params = extractConvWithBiasParams ( 512 , 9 , 1 )
109
+ const box_encoding_1_predictor_params = extractConvWithBiasParams ( 1024 , 24 , 1 )
110
+ const class_predictor_1_params = extractConvWithBiasParams ( 1024 , 18 , 1 )
111
+ const box_encoding_2_predictor_params = extractConvWithBiasParams ( 512 , 24 , 1 )
112
+ const class_predictor_2_params = extractConvWithBiasParams ( 512 , 18 , 1 )
113
+ const box_encoding_3_predictor_params = extractConvWithBiasParams ( 256 , 24 , 1 )
114
+ const class_predictor_3_params = extractConvWithBiasParams ( 256 , 18 , 1 )
115
+ const box_encoding_4_predictor_params = extractConvWithBiasParams ( 256 , 24 , 1 )
116
+ const class_predictor_4_params = extractConvWithBiasParams ( 256 , 18 , 1 )
117
+ const box_encoding_5_predictor_params = extractConvWithBiasParams ( 128 , 24 , 1 )
118
+ const class_predictor_5_params = extractConvWithBiasParams ( 128 , 18 , 1 )
119
+
120
+ const box_predictor_0_params = {
121
+ box_encoding_predictor_params : box_encoding_0_predictor_params ,
122
+ class_predictor_params : class_predictor_0_params
123
+ }
124
+ const box_predictor_1_params = {
125
+ box_encoding_predictor_params : box_encoding_1_predictor_params ,
126
+ class_predictor_params : class_predictor_1_params
127
+ }
128
+ const box_predictor_2_params = {
129
+ box_encoding_predictor_params : box_encoding_2_predictor_params ,
130
+ class_predictor_params : class_predictor_2_params
131
+ }
132
+ const box_predictor_3_params = {
133
+ box_encoding_predictor_params : box_encoding_3_predictor_params ,
134
+ class_predictor_params : class_predictor_3_params
135
+ }
136
+ const box_predictor_4_params = {
137
+ box_encoding_predictor_params : box_encoding_4_predictor_params ,
138
+ class_predictor_params : class_predictor_4_params
139
+ }
140
+ const box_predictor_5_params = {
141
+ box_encoding_predictor_params : box_encoding_5_predictor_params ,
142
+ class_predictor_params : class_predictor_5_params
143
+ }
144
+
145
+ return {
146
+ conv_0_params,
147
+ conv_1_params,
148
+ conv_2_params,
149
+ conv_3_params,
150
+ conv_4_params,
151
+ conv_5_params,
152
+ conv_6_params,
153
+ conv_7_params,
154
+ box_predictor_0_params,
155
+ box_predictor_1_params,
156
+ box_predictor_2_params,
157
+ box_predictor_3_params,
158
+ box_predictor_4_params,
159
+ box_predictor_5_params
160
+ }
161
+ }
162
+
78
163
79
164
return {
80
- extractMobilenetV1Params
165
+ extractMobilenetV1Params,
166
+ extractPredictionLayerParams
81
167
}
82
168
83
169
}
84
170
85
171
export function extractParams ( weights : Float32Array ) : FaceDetectionNet . NetParams {
86
172
const extractWeights = ( numWeights : number ) : Float32Array => {
87
- console . log ( numWeights )
88
173
const ret = weights . slice ( 0 , numWeights )
89
174
weights = weights . slice ( numWeights )
90
175
return ret
91
176
}
92
177
93
178
const {
94
- extractMobilenetV1Params
179
+ extractMobilenetV1Params,
180
+ extractPredictionLayerParams
95
181
} = extractorsFactory ( extractWeights )
96
182
97
183
const mobilenetv1_params = extractMobilenetV1Params ( )
184
+ const prediction_layer_params = extractPredictionLayerParams ( )
98
185
99
186
if ( weights . length !== 0 ) {
100
187
throw new Error ( `weights remaing after extract: ${ weights . length } ` )
101
188
}
102
189
103
190
return {
104
- mobilenetv1_params
191
+ mobilenetv1_params,
192
+ prediction_layer_params
105
193
}
106
194
}
0 commit comments