Skip to content

Commit d5e57cc

Browse files
committed
[NeoMathEngine] AVX512 add vectorHSwish
Signed-off-by: Kirill Golikov <[email protected]>
1 parent fb8261f commit d5e57cc

File tree

3 files changed

+66
-2
lines changed

3 files changed

+66
-2
lines changed

NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,10 @@ inline void vectorSigmoid( const float* first, float* result, int vectorSize )
11551155

11561156
inline void vectorHSwish( const float* first, float* result, int vectorSize )
11571157
{
1158-
if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
1158+
if( CCPUInfo::HasAvx512 && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
1159+
NeoML::Avx512::vectorHSwish( first, result, vectorSize );
1160+
return;
1161+
} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
11591162
NeoML::Avx2::vectorHSwish( first, result, vectorSize );
11601163
return;
11611164
}

NeoMathEngine/src/CPU/x86/avx512/Avx512Functions.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace NeoML {
2323

2424
namespace Avx512 {
2525

26-
// The minimum vector size recommended for using AVX vector functions
26+
// The minimum vector size recommended for using AVX512 vector functions
2727
static constexpr int VectorMathMinSize = 32;
2828

2929
void dataCopy( float* dst, const float* src, int vectorSize );
@@ -40,6 +40,8 @@ void vectorReLU( const float* first, float* result, int vectorSize );
4040

4141
void vectorReLU( const float* first, float* result, int vectorSize, float threshold );
4242

43+
void vectorHSwish( const float* first, float* result, int vectorSize );
44+
4345
} // namespace Avx512
4446

4547
} // namespace NeoML

NeoMathEngine/src/CPU/x86/avx512/Avx512VectorFunctions.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,65 @@ void vectorReLU( const float* first, float* result, int vectorSize, float thresh
238238
}
239239
}
240240

241+
void vectorHSwish( const float* first, float* result, int vectorSize )
242+
{
243+
const __m512 minusThreeSimd = _mm512_set1_ps( -3.f );
244+
const __m512 threeSimd = _mm512_set1_ps( 3.f );
245+
const __m512 oneSixthSimd = _mm512_set1_ps( 1.f / 6.f );
246+
const __m512 zeroSimd = _mm512_setzero_ps();
247+
248+
//for( int i = 0; i < nonSseSize; ++i ) {
249+
// if( *first <= -3.f ) {
250+
// *result = 0.f;
251+
// } else if( *first >= 3.f ) {
252+
// *result = *first;
253+
// } else {
254+
// *result = *first * ( 1. / 6. ) * ( *first + 3 );
255+
// }
256+
// ++result;
257+
// ++first;
258+
//}
259+
260+
while( vectorSize >= AvxBlockSize ) {
261+
const __m512 firstSimd = _mm512_loadu_ps( first );
262+
263+
const __mmask16 middleMask = _mm512_cmp_ps_mask( firstSimd, minusThreeSimd, _CMP_GT_OQ ); // ( first > -3. )
264+
const __mmask16 rightMask = _mm512_cmp_ps_mask( firstSimd, threeSimd, _CMP_LT_OQ ); // ( first < 3. )
265+
266+
const __m512 middleSimd = _mm512_mask_blend_ps( middleMask, zeroSimd/*else*/, firstSimd ); // result = ( first > -3. ) ? first : 0.
267+
268+
const __m512 resultSimd = _mm512_mask_mul_ps( // result = ( middleMask & rightMask ) ? ( first * ( 1. / 6. ) ) * ( first + 3. ) : middleSimd
269+
middleSimd /*else*/,
270+
middleMask & rightMask,
271+
_mm512_mul_ps( firstSimd, oneSixthSimd ), // ( first * ( 1. / 6. ) ) *
272+
_mm512_add_ps( firstSimd, threeSimd ) ); // ( first + 3 )
273+
274+
_mm512_storeu_ps( result, resultSimd );
275+
276+
first += AvxBlockSize;
277+
result += AvxBlockSize;
278+
vectorSize -= AvxBlockSize;
279+
}
280+
281+
if( vectorSize > 0 ) {
282+
const __mmask16 mask = AVX512_IO_MASK( vectorSize );
283+
284+
const __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first );
285+
286+
const __mmask16 middleMask = _mm512_cmp_ps_mask( firstSimd, minusThreeSimd, _CMP_GT_OQ ); // ( first > -3. )
287+
const __mmask16 rightMask = _mm512_cmp_ps_mask( firstSimd, threeSimd, _CMP_LT_OQ ); // ( first < 3. )
288+
289+
const __m512 middleSimd = _mm512_mask_blend_ps( middleMask, zeroSimd/*else*/, firstSimd ); // result = ( first > -3. ) ? first : 0.
290+
291+
const __m512 resultSimd = _mm512_mask_mul_ps( // result = ( middleMask & rightMask ) ? ( first * ( 1. / 6. ) ) * ( first + 3. ) : middleSimd
292+
middleSimd /*else*/,
293+
middleMask & rightMask,
294+
_mm512_mul_ps( firstSimd, oneSixthSimd ), // ( first * ( 1. / 6. ) ) *
295+
_mm512_add_ps( firstSimd, threeSimd ) ); // ( first + 3 )
296+
297+
_mm512_mask_store_ps( result, mask, resultSimd );
298+
}
299+
}
241300

242301
} // namespace Avx512
243302

0 commit comments

Comments
 (0)