Skip to content

Commit 4612399

Browse files
committed
[NeoMathEngine] speed-up VectorBenchmarkTest
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 84284be commit 4612399

File tree

1 file changed

+119
-77
lines changed

1 file changed

+119
-77
lines changed

NeoMathEngine/test/src/inference/VectorBenchmark.cpp

Lines changed: 119 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -35,76 +35,118 @@ static const char* vectorFunctionsNames[]{
3535
"VectorHSwish"
3636
};
3737

38-
static void vectorBenchmark( int function, int testCount, int vectorSize, const CInterval& vectorValuesInterval,
39-
int seed, IPerformanceCounters& counters, std::ofstream& fout )
38+
//------------------------------------------------------------------------------------------------------------
39+
40+
class VectorBenchmarkParams final {
41+
public:
42+
int Function = -2;
43+
int TestCount = -1;
44+
int VectorSize = -1;
45+
IPerformanceCounters* Counters = nullptr;
46+
std::ofstream FOut{};
47+
48+
VectorBenchmarkParams( int function, int testCount, int vectorSize,
49+
const CInterval& valuesInterval, int seed );
50+
~VectorBenchmarkParams() { delete Counters; }
51+
52+
void SetNextSeedForFunction( int function, int seed );
53+
54+
CFloatWrapper GetInputBuffer() { return CFloatWrapper( MathEngine(), input.data(), VectorSize ); }
55+
CFloatWrapper GetSecondBuffer() { return CFloatWrapper( MathEngine(), second.data(), VectorSize ); }
56+
CFloatWrapper GetResultBuffer() { return CFloatWrapper( MathEngine(), result.data(), VectorSize ); }
57+
CFloatWrapper GetZeroVal() { float zero = 0; return CFloatWrapper( MathEngine(), &zero, 1 ); }
58+
CFloatWrapper GetMulVal() {
59+
float multiplier = static_cast<float>( random.Uniform( 1, valuesInterval.End ) );
60+
return CFloatWrapper( MathEngine(), &multiplier, 1 );
61+
}
62+
63+
private:
64+
const CInterval& valuesInterval;
65+
CRandom random;
66+
67+
std::vector<float> input;
68+
std::vector<float> second;
69+
std::vector<float> result;
70+
};
71+
72+
VectorBenchmarkParams::VectorBenchmarkParams( int function, int testCount, int vectorSize,
73+
const CInterval& valuesInterval, int seed ) :
74+
Function( function ),
75+
TestCount( testCount ),
76+
VectorSize( vectorSize ),
77+
Counters( MathEngine().CreatePerformanceCounters() ),
78+
FOut( std::ofstream( "VectorBenchmarkTest.csv", std::ios::app ) ),
79+
valuesInterval( valuesInterval )
80+
{
81+
FOut << "\n---------------------------" << std::endl;
82+
input.resize( vectorSize );
83+
second.resize( vectorSize );
84+
result.resize( vectorSize );
85+
SetNextSeedForFunction( function, seed );
86+
}
87+
88+
void VectorBenchmarkParams::SetNextSeedForFunction( int function, int seed )
4089
{
41-
CRandom random( seed );
42-
CREATE_FILL_FLOAT_ARRAY( input, vectorValuesInterval.Begin, vectorValuesInterval.End, vectorSize, random )
43-
CREATE_FILL_FLOAT_ARRAY( second, vectorValuesInterval.Begin, vectorValuesInterval.End, vectorSize, random )
44-
std::vector<float> result( vectorSize, 0 );
45-
46-
float zero = 0;
47-
float multiplier = static_cast<float>( random.Uniform( 1, vectorValuesInterval.End ) );
48-
CFloatWrapper zeroVal( MathEngine(), &zero, 1 );
49-
CFloatWrapper mulVal( MathEngine(), &multiplier, 1 );
50-
ASSERT_EXPR( ( ( CConstFloatHandle )mulVal ).GetValueAt( 0 ) > 0 );
51-
52-
if( function == -1 ) { // warm-up
53-
return;
90+
Function = function;
91+
random = CRandom( seed );
92+
for( int i = 0; i < VectorSize; ++i ) {
93+
input[i] = static_cast<float>( random.Uniform( valuesInterval.Begin, valuesInterval.End ) );
94+
second[i] = static_cast<float>( random.Uniform( valuesInterval.Begin, valuesInterval.End ) );
95+
result[i] = 0;
96+
}
97+
}
98+
99+
//------------------------------------------------------------------------------------------------------------
100+
101+
static double vectorBenchmark( VectorBenchmarkParams& params )
102+
{
103+
CFloatWrapper zeroVal = params.GetInputBuffer();
104+
CFloatWrapper mulVal = params.GetMulVal();
105+
CConstFloatHandle mulHandle = mulVal;
106+
ASSERT_EXPR( mulHandle.GetValueAt( 0 ) > 0 );
107+
108+
CFloatWrapper input = params.GetInputBuffer();
109+
CFloatWrapper second = params.GetSecondBuffer();
110+
CFloatWrapper result = params.GetResultBuffer();
111+
const int vectorSize = params.VectorSize;
112+
113+
if( params.Function == -1 ) { // warm-up
114+
MathEngine().VectorCopy( result, input, vectorSize );
115+
MathEngine().VectorFill( result, vectorSize, mulVal );
116+
MathEngine().VectorAdd( input, second, result, vectorSize );
117+
MathEngine().VectorAddValue( input, result, vectorSize, mulVal );
118+
MathEngine().VectorMultiply( input, second, vectorSize, mulVal );
119+
MathEngine().VectorEltwiseMultiply( input, second, result, vectorSize );
120+
MathEngine().VectorEltwiseMultiplyAdd( input, second, result, vectorSize );
121+
MathEngine().VectorReLU( input, result, vectorSize, zeroVal ); //Threshold == 0
122+
MathEngine().VectorReLU( input, result, vectorSize, mulVal ); //Threshold > 0
123+
MathEngine().VectorHSwish( input, result, vectorSize );
124+
return 0;
54125
}
55126

56-
counters.Synchronise();
57-
58-
for( int i = 0; i < testCount; ++i ) {
59-
switch( function ) {
60-
case 0:
61-
MathEngine().VectorCopy( CARRAY_FLOAT_WRAPPER( result ), CARRAY_FLOAT_WRAPPER( input ), vectorSize );
62-
break;
63-
case 1:
64-
MathEngine().VectorFill( CARRAY_FLOAT_WRAPPER( result ), vectorSize, mulVal );
65-
break;
66-
case 2:
67-
MathEngine().VectorAdd( CARRAY_FLOAT_WRAPPER( input ), CARRAY_FLOAT_WRAPPER( second ),
68-
CARRAY_FLOAT_WRAPPER( result ), vectorSize );
69-
break;
70-
case 3:
71-
MathEngine().VectorAddValue( CARRAY_FLOAT_WRAPPER( input ), CARRAY_FLOAT_WRAPPER( result ),
72-
vectorSize, mulVal );
73-
break;
74-
case 4:
75-
MathEngine().VectorMultiply( CARRAY_FLOAT_WRAPPER( input ),
76-
CARRAY_FLOAT_WRAPPER( second ), vectorSize, mulVal );
77-
break;
78-
case 5:
79-
MathEngine().VectorEltwiseMultiply( CARRAY_FLOAT_WRAPPER( input ),
80-
CARRAY_FLOAT_WRAPPER( second ), CARRAY_FLOAT_WRAPPER( result ), vectorSize );
81-
break;
82-
case 6:
83-
MathEngine().VectorEltwiseMultiplyAdd( CARRAY_FLOAT_WRAPPER( input ),
84-
CARRAY_FLOAT_WRAPPER( second ), CARRAY_FLOAT_WRAPPER( result ), vectorSize );
85-
break;
86-
case 7:
87-
MathEngine().VectorReLU( CARRAY_FLOAT_WRAPPER( input ), CARRAY_FLOAT_WRAPPER( result ),
88-
vectorSize, zeroVal ); //Threshold == 0
89-
break;
90-
case 8:
91-
MathEngine().VectorReLU( CARRAY_FLOAT_WRAPPER( input ), CARRAY_FLOAT_WRAPPER( result ),
92-
vectorSize, mulVal ); //Threshold > 0
93-
break;
94-
case 9:
95-
MathEngine().VectorHSwish( CARRAY_FLOAT_WRAPPER( input ), CARRAY_FLOAT_WRAPPER( result ),
96-
vectorSize );
97-
break;
127+
params.Counters->Synchronise();
128+
129+
for( int i = 0; i < params.TestCount; ++i ) {
130+
switch( params.Function ) {
131+
case 0: MathEngine().VectorCopy( result, input, vectorSize ); break;
132+
case 1: MathEngine().VectorFill( result, vectorSize, mulVal ); break;
133+
case 2: MathEngine().VectorAdd( input, second, result, vectorSize ); break;
134+
case 3: MathEngine().VectorAddValue( input, result, vectorSize, mulVal ); break;
135+
case 4: MathEngine().VectorMultiply( input, second, vectorSize, mulVal ); break;
136+
case 5: MathEngine().VectorEltwiseMultiply( input, second, result, vectorSize ); break;
137+
case 6: MathEngine().VectorEltwiseMultiplyAdd( input, second, result, vectorSize ); break;
138+
case 7: MathEngine().VectorReLU( input, result, vectorSize, zeroVal ); break; //Threshold == 0
139+
case 8: MathEngine().VectorReLU( input, result, vectorSize, mulVal ); break; //Threshold > 0
140+
case 9: MathEngine().VectorHSwish( input, result, vectorSize ); break;
98141
default:
99142
ASSERT_EXPR( false );
100143
}
101144
}
102145

103-
counters.Synchronise();
104-
const double time = double( counters[0].Value ) / 1000000 / testCount; // average time in milliseconds
105-
106-
GTEST_LOG_( INFO ) << vectorFunctionsNames[function] << ", " << time;
107-
fout << vectorFunctionsNames[function] << "," << time << "\n";
146+
params.Counters->Synchronise();
147+
const double time = double( ( *params.Counters )[0].Value ) / 1000000 / params.TestCount; // average time in milliseconds
148+
params.FOut << time << ",";
149+
return time;
108150
}
109151

110152
} // namespace NeoMLTest
@@ -129,34 +171,34 @@ INSTANTIATE_TEST_CASE_P( CMathEngineVectorBenchmarkTestInstantiation, CMathEngin
129171
"VectorValues = (-10..10);"
130172
),
131173
CTestParams(
132-
"TestCount = 100000;"
174+
"TestCount = 1000;"
133175
"RepeatCount = 10;"
134-
"VectorSize = 11796480;"
176+
"VectorSize = 1179648;"
135177
"VectorValues = (-1..1);"
136178
)
137179
)
138180
);
139181

140182
TEST_P( CMathEngineVectorBenchmarkTest, DISABLED_Random )
141183
{
142-
CTestParams params = GetParam();
143-
144-
const int testCount = params.GetValue<int>( "TestCount" );
145-
const int repeatCount = params.GetValue<int>( "RepeatCount" );
146-
const int vectorSize = params.GetValue<int>( "VectorSize" );
147-
const CInterval vectorValuesInterval = params.GetInterval( "VectorValues" );
184+
CTestParams testParams = GetParam();
185+
const int testCount = testParams.GetValue<int>( "TestCount" );
186+
const int repeatCount = testParams.GetValue<int>( "RepeatCount" );
187+
const int vectorSize = testParams.GetValue<int>( "VectorSize" );
188+
const CInterval valuesInterval = testParams.GetInterval( "VectorValues" );
148189

149-
IPerformanceCounters* counters = MathEngine().CreatePerformanceCounters();
150-
std::ofstream fout( "VectorBenchmarkTest.csv", std::ios::app );
151-
fout << "---------------------------\n";
152-
153-
vectorBenchmark( /*warm-up*/-1, testCount, vectorSize, vectorValuesInterval, 282, *counters, fout);
190+
VectorBenchmarkParams params( /*warm-up*/-1, testCount, vectorSize, valuesInterval, 282 );
191+
vectorBenchmark( params );
154192

155193
for( int function = 0; function < 10; ++function ) {
194+
params.FOut << std::endl << vectorFunctionsNames[function] << ",";
195+
196+
double timeSum = 0;
156197
for( int test = 0; test < repeatCount; ++test ) {
157198
const int seed = 282 + test * 10000 + test % 3;
158-
vectorBenchmark( function, testCount, vectorSize, vectorValuesInterval, seed, *counters, fout );
199+
params.SetNextSeedForFunction( function, seed );
200+
timeSum += vectorBenchmark( params );
159201
}
202+
GTEST_LOG_( INFO ) << vectorFunctionsNames[function] << "\t" << timeSum;
160203
}
161-
delete counters;
162204
}

0 commit comments

Comments
 (0)