@@ -238,6 +238,65 @@ void vectorReLU( const float* first, float* result, int vectorSize, float thresh
238
238
}
239
239
}
240
240
241
+ void vectorHSwish ( const float * first, float * result, int vectorSize )
242
+ {
243
+ const __m512 minusThreeSimd = _mm512_set1_ps ( -3 .f );
244
+ const __m512 threeSimd = _mm512_set1_ps ( 3 .f );
245
+ const __m512 oneSixthSimd = _mm512_set1_ps ( 1 .f / 6 .f );
246
+ const __m512 zeroSimd = _mm512_setzero_ps ();
247
+
248
+ // for( int i = 0; i < nonSseSize; ++i ) {
249
+ // if( *first <= -3.f ) {
250
+ // *result = 0.f;
251
+ // } else if( *first >= 3.f ) {
252
+ // *result = *first;
253
+ // } else {
254
+ // *result = *first * ( 1. / 6. ) * ( *first + 3 );
255
+ // }
256
+ // ++result;
257
+ // ++first;
258
+ // }
259
+
260
+ while ( vectorSize >= AvxBlockSize ) {
261
+ const __m512 firstSimd = _mm512_loadu_ps ( first );
262
+
263
+ const __mmask16 middleMask = _mm512_cmp_ps_mask ( firstSimd, minusThreeSimd, _CMP_GT_OQ ); // ( first > -3. )
264
+ const __mmask16 rightMask = _mm512_cmp_ps_mask ( firstSimd, threeSimd, _CMP_LT_OQ ); // ( first < 3. )
265
+
266
+ const __m512 middleSimd = _mm512_mask_blend_ps ( middleMask, zeroSimd/* else*/ , firstSimd ); // result = ( first > -3. ) ? first : 0.
267
+
268
+ const __m512 resultSimd = _mm512_mask_mul_ps ( // result = ( middleMask & rightMask ) ? ( first * ( 1. / 6. ) ) * ( first + 3. ) : middleSimd
269
+ middleSimd /* else*/ ,
270
+ middleMask & rightMask,
271
+ _mm512_mul_ps ( firstSimd, oneSixthSimd ), // ( first * ( 1. / 6. ) ) *
272
+ _mm512_add_ps ( firstSimd, threeSimd ) ); // ( first + 3 )
273
+
274
+ _mm512_storeu_ps ( result, resultSimd );
275
+
276
+ first += AvxBlockSize;
277
+ result += AvxBlockSize;
278
+ vectorSize -= AvxBlockSize;
279
+ }
280
+
281
+ if ( vectorSize > 0 ) {
282
+ const __mmask16 mask = AVX512_IO_MASK ( vectorSize );
283
+
284
+ const __m512 firstSimd = _mm512_mask_loadu_ps ( zeroSimd, mask, first );
285
+
286
+ const __mmask16 middleMask = _mm512_cmp_ps_mask ( firstSimd, minusThreeSimd, _CMP_GT_OQ ); // ( first > -3. )
287
+ const __mmask16 rightMask = _mm512_cmp_ps_mask ( firstSimd, threeSimd, _CMP_LT_OQ ); // ( first < 3. )
288
+
289
+ const __m512 middleSimd = _mm512_mask_blend_ps ( middleMask, zeroSimd/* else*/ , firstSimd ); // result = ( first > -3. ) ? first : 0.
290
+
291
+ const __m512 resultSimd = _mm512_mask_mul_ps ( // result = ( middleMask & rightMask ) ? ( first * ( 1. / 6. ) ) * ( first + 3. ) : middleSimd
292
+ middleSimd /* else*/ ,
293
+ middleMask & rightMask,
294
+ _mm512_mul_ps ( firstSimd, oneSixthSimd ), // ( first * ( 1. / 6. ) ) *
295
+ _mm512_add_ps ( firstSimd, threeSimd ) ); // ( first + 3 )
296
+
297
+ _mm512_mask_store_ps ( result, mask, resultSimd );
298
+ }
299
+ }
241
300
242
301
} // namespace Avx512
243
302
0 commit comments