@@ -72,7 +72,7 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
     int outCn = outShape[1];
     size_t outPlaneSize = outShape[2]*outShape[3];
     float r0 = 1.f, r1 = 1.f, r2 = 1.f;
-    __m256 vr0 = _mm256_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm256_setzero_ps();
+    __m128 vr0 = _mm_set1_ps(1.f), vr1 = vr0, vr2 = vr0, z = _mm_setzero_ps();

     // now compute dot product of the weights
     // and im2row-transformed part of the tensor
@@ -104,9 +104,9 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             r0 = relu[i];
             r1 = relu[i+1];
             r2 = relu[i+2];
-            vr0 = _mm256_set1_ps(r0);
-            vr1 = _mm256_set1_ps(r1);
-            vr2 = _mm256_set1_ps(r2);
+            vr0 = _mm_set1_ps(r0);
+            vr1 = _mm_set1_ps(r1);
+            vr2 = _mm_set1_ps(r2);
         }

         int j = 0;
@@ -156,38 +156,38 @@ void fastConv( const float* weights, size_t wstep, const float* bias,
             t1 = _mm256_add_ps(t1, _mm256_permute2f128_ps(t1, t1, 1));
             t2 = _mm256_add_ps(t2, _mm256_permute2f128_ps(t2, t2, 1));

-            __m256 s0, s1, s2;
+            __m128 s0, s1, s2;

             if( initOutput )
             {
-                s0 = _mm256_set1_ps(bias0);
-                s1 = _mm256_set1_ps(bias1);
-                s2 = _mm256_set1_ps(bias2);
+                s0 = _mm_set1_ps(bias0);
+                s1 = _mm_set1_ps(bias1);
+                s2 = _mm_set1_ps(bias2);
             }
             else
             {
-                s0 = _mm256_castps128_ps256(_mm_loadu_ps(outptr0 + j));
-                s1 = _mm256_castps128_ps256(_mm_loadu_ps(outptr1 + j));
-                s2 = _mm256_castps128_ps256(_mm_loadu_ps(outptr2 + j));
+                s0 = _mm_loadu_ps(outptr0 + j);
+                s1 = _mm_loadu_ps(outptr1 + j);
+                s2 = _mm_loadu_ps(outptr2 + j);
             }

-            s0 = _mm256_add_ps(s0, t0);
-            s1 = _mm256_add_ps(s1, t1);
-            s2 = _mm256_add_ps(s2, t2);
+            s0 = _mm_add_ps(s0, _mm256_castps256_ps128(t0));
+            s1 = _mm_add_ps(s1, _mm256_castps256_ps128(t1));
+            s2 = _mm_add_ps(s2, _mm256_castps256_ps128(t2));

             if( relu )
             {
-                __m256 m0 = _mm256_cmp_ps(s0, z, _CMP_GT_OS);
-                __m256 m1 = _mm256_cmp_ps(s1, z, _CMP_GT_OS);
-                __m256 m2 = _mm256_cmp_ps(s2, z, _CMP_GT_OS);
-                s0 = _mm256_xor_ps(s0, _mm256_andnot_ps(m0, _mm256_xor_ps(_mm256_mul_ps(s0, vr0), s0)));
-                s1 = _mm256_xor_ps(s1, _mm256_andnot_ps(m1, _mm256_xor_ps(_mm256_mul_ps(s1, vr1), s1)));
-                s2 = _mm256_xor_ps(s2, _mm256_andnot_ps(m2, _mm256_xor_ps(_mm256_mul_ps(s2, vr2), s2)));
+                __m128 m0 = _mm_cmp_ps(s0, z, _CMP_GT_OS);
+                __m128 m1 = _mm_cmp_ps(s1, z, _CMP_GT_OS);
+                __m128 m2 = _mm_cmp_ps(s2, z, _CMP_GT_OS);
+                s0 = _mm_xor_ps(s0, _mm_andnot_ps(m0, _mm_xor_ps(_mm_mul_ps(s0, vr0), s0)));
+                s1 = _mm_xor_ps(s1, _mm_andnot_ps(m1, _mm_xor_ps(_mm_mul_ps(s1, vr1), s1)));
+                s2 = _mm_xor_ps(s2, _mm_andnot_ps(m2, _mm_xor_ps(_mm_mul_ps(s2, vr2), s2)));
             }

-            _mm_storeu_ps(outptr0 + j, _mm256_castps256_ps128(s0));
-            _mm_storeu_ps(outptr1 + j, _mm256_castps256_ps128(s1));
-            _mm_storeu_ps(outptr2 + j, _mm256_castps256_ps128(s2));
+            _mm_storeu_ps(outptr0 + j, s0);
+            _mm_storeu_ps(outptr1 + j, s1);
+            _mm_storeu_ps(outptr2 + j, s2);
         }

         for( ; j < blockSize; j++ )
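Note on the relu block above: the cmp/xor/andnot sequence is a branchless per-lane select that computes s = (s > 0) ? s : s*slope with no per-element branch. Where the compare mask is all-ones, andnot zeroes the correction term and the xor leaves s unchanged; where the mask is zero, the xor swaps in s*slope. A minimal standalone sketch of the same trick (not part of the patch; the helper name leakyRelu4 and the test values are made up, and _mm_cmpgt_ps stands in as the plain-SSE equivalent of _mm_cmp_ps(s, z, _CMP_GT_OS)):

#include <immintrin.h>
#include <cstdio>

// Branchless leaky ReLU over 4 packed floats.
static void leakyRelu4(float* v, float slope)
{
    __m128 s  = _mm_loadu_ps(v);
    __m128 vr = _mm_set1_ps(slope);
    __m128 m  = _mm_cmpgt_ps(s, _mm_setzero_ps());                // all-ones where s > 0
    s = _mm_xor_ps(s, _mm_andnot_ps(m, _mm_xor_ps(_mm_mul_ps(s, vr), s)));
    _mm_storeu_ps(v, s);
}

int main()
{
    float v[4] = { -2.f, -0.5f, 0.f, 3.f };
    leakyRelu4(v, 0.1f);
    printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);              // expected: -0.2 -0.05 0 3
    return 0;
}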
@@ -294,11 +294,63 @@ void fastGEMM1T( const float* vec, const float* weights,
     _mm256_zeroupper();
 }

+
 void fastGEMM( const float* aptr, size_t astep, const float* bptr,
                size_t bstep, float* cptr, size_t cstep,
                int ma, int na, int nb )
 {
     int n = 0;
+
+#if CV_AVX_512F
+    for( ; n <= nb - 32; n += 32 )
+    {
+        for( int m = 0; m < ma; m += 4 )
+        {
+            const float* aptr0 = aptr + astep*m;
+            const float* aptr1 = aptr + astep*std::min(m+1, ma-1);
+            const float* aptr2 = aptr + astep*std::min(m+2, ma-1);
+            const float* aptr3 = aptr + astep*std::min(m+3, ma-1);
+
+            float* cptr0 = cptr + cstep*m;
+            float* cptr1 = cptr + cstep*std::min(m+1, ma-1);
+            float* cptr2 = cptr + cstep*std::min(m+2, ma-1);
+            float* cptr3 = cptr + cstep*std::min(m+3, ma-1);
+
+            __m512 d00 = _mm512_setzero_ps(), d01 = _mm512_setzero_ps();
+            __m512 d10 = _mm512_setzero_ps(), d11 = _mm512_setzero_ps();
+            __m512 d20 = _mm512_setzero_ps(), d21 = _mm512_setzero_ps();
+            __m512 d30 = _mm512_setzero_ps(), d31 = _mm512_setzero_ps();
+
+            for( int k = 0; k < na; k++ )
+            {
+                __m512 a0 = _mm512_set1_ps(aptr0[k]);
+                __m512 a1 = _mm512_set1_ps(aptr1[k]);
+                __m512 a2 = _mm512_set1_ps(aptr2[k]);
+                __m512 a3 = _mm512_set1_ps(aptr3[k]);
+                __m512 b0 = _mm512_loadu_ps(bptr + k*bstep + n);
+                __m512 b1 = _mm512_loadu_ps(bptr + k*bstep + n + 16);
+                d00 = _mm512_fmadd_ps(a0, b0, d00);
+                d01 = _mm512_fmadd_ps(a0, b1, d01);
+                d10 = _mm512_fmadd_ps(a1, b0, d10);
+                d11 = _mm512_fmadd_ps(a1, b1, d11);
+                d20 = _mm512_fmadd_ps(a2, b0, d20);
+                d21 = _mm512_fmadd_ps(a2, b1, d21);
+                d30 = _mm512_fmadd_ps(a3, b0, d30);
+                d31 = _mm512_fmadd_ps(a3, b1, d31);
+            }
+
+            _mm512_storeu_ps(cptr0 + n, d00);
+            _mm512_storeu_ps(cptr0 + n + 16, d01);
+            _mm512_storeu_ps(cptr1 + n, d10);
+            _mm512_storeu_ps(cptr1 + n + 16, d11);
+            _mm512_storeu_ps(cptr2 + n, d20);
+            _mm512_storeu_ps(cptr2 + n + 16, d21);
+            _mm512_storeu_ps(cptr3 + n, d30);
+            _mm512_storeu_ps(cptr3 + n + 16, d31);
+        }
+    }
+#endif
+
     for( ; n <= nb - 16; n += 16 )
     {
         for( int m = 0; m < ma; m += 4 )
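Note on the AVX-512 block above: each iteration of the m loop computes a 4-row by 32-column tile of the output with eight zero-initialized __m512 accumulators, broadcasting one element of each A row per k step and overwriting cptr with the finished tile (the std::min clamping simply recomputes the last row when ma is not a multiple of 4). A scalar sketch of what one tile computes, assuming the row-major layout with strides astep/bstep/cstep that the intrinsics index (gemmTileRef is a made-up name, not part of the patch):

#include <algorithm>
#include <cstddef>

// Scalar reference for one 4x32 tile: C[row][n+dn] = sum_k A[row][k] * B[k][n+dn].
static void gemmTileRef(const float* aptr, size_t astep,
                        const float* bptr, size_t bstep,
                        float* cptr, size_t cstep,
                        int m, int n, int ma, int na)
{
    for (int dm = 0; dm < 4; dm++)
    {
        int row = std::min(m + dm, ma - 1);          // same row clamping as the patch
        for (int dn = 0; dn < 32; dn++)
        {
            float acc = 0.f;
            for (int k = 0; k < na; k++)
                acc += aptr[astep*row + k] * bptr[bstep*k + n + dn];
            cptr[cstep*row + n + dn] = acc;          // overwrites, like _mm512_storeu_ps
        }
    }
}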