@@ -35,76 +35,118 @@ static const char* vectorFunctionsNames[]{
35
35
" VectorHSwish"
36
36
};
37
37
38
- static void vectorBenchmark ( int function, int testCount, int vectorSize, const CInterval& vectorValuesInterval,
39
- int seed, IPerformanceCounters& counters, std::ofstream& fout )
38
+ // ------------------------------------------------------------------------------------------------------------
39
+
40
+ class VectorBenchmarkParams final {
41
+ public:
42
+ int Function = -2 ;
43
+ int TestCount = -1 ;
44
+ int VectorSize = -1 ;
45
+ IPerformanceCounters* Counters = nullptr ;
46
+ std::ofstream FOut{};
47
+
48
+ VectorBenchmarkParams ( int function, int testCount, int vectorSize,
49
+ const CInterval& valuesInterval, int seed );
50
+ ~VectorBenchmarkParams () { delete Counters; }
51
+
52
+ void SetNextSeedForFunction ( int function, int seed );
53
+
54
+ CFloatWrapper GetInputBuffer () { return CFloatWrapper ( MathEngine (), input.data (), VectorSize ); }
55
+ CFloatWrapper GetSecondBuffer () { return CFloatWrapper ( MathEngine (), second.data (), VectorSize ); }
56
+ CFloatWrapper GetResultBuffer () { return CFloatWrapper ( MathEngine (), result.data (), VectorSize ); }
57
+ CFloatWrapper GetZeroVal () { float zero = 0 ; return CFloatWrapper ( MathEngine (), &zero, 1 ); }
58
+ CFloatWrapper GetMulVal () {
59
+ float multiplier = static_cast <float >( random.Uniform ( 1 , valuesInterval.End ) );
60
+ return CFloatWrapper ( MathEngine (), &multiplier, 1 );
61
+ }
62
+
63
+ private:
64
+ const CInterval& valuesInterval;
65
+ CRandom random;
66
+
67
+ std::vector<float > input;
68
+ std::vector<float > second;
69
+ std::vector<float > result;
70
+ };
71
+
72
+ VectorBenchmarkParams::VectorBenchmarkParams ( int function, int testCount, int vectorSize,
73
+ const CInterval& valuesInterval, int seed ) :
74
+ Function ( function ),
75
+ TestCount ( testCount ),
76
+ VectorSize ( vectorSize ),
77
+ Counters ( MathEngine().CreatePerformanceCounters() ),
78
+ FOut ( std::ofstream( " VectorBenchmarkTest.csv" , std::ios::app ) ),
79
+ valuesInterval ( valuesInterval )
80
+ {
81
+ FOut << " \n ---------------------------" << std::endl;
82
+ input.resize ( vectorSize );
83
+ second.resize ( vectorSize );
84
+ result.resize ( vectorSize );
85
+ SetNextSeedForFunction ( function, seed );
86
+ }
87
+
88
+ void VectorBenchmarkParams::SetNextSeedForFunction ( int function, int seed )
40
89
{
41
- CRandom random ( seed );
42
- CREATE_FILL_FLOAT_ARRAY ( input, vectorValuesInterval.Begin , vectorValuesInterval.End , vectorSize, random )
43
- CREATE_FILL_FLOAT_ARRAY ( second, vectorValuesInterval.Begin , vectorValuesInterval.End , vectorSize, random )
44
- std::vector<float > result ( vectorSize, 0 );
45
-
46
- float zero = 0 ;
47
- float multiplier = static_cast <float >( random.Uniform ( 1 , vectorValuesInterval.End ) );
48
- CFloatWrapper zeroVal ( MathEngine (), &zero, 1 );
49
- CFloatWrapper mulVal ( MathEngine (), &multiplier, 1 );
50
- ASSERT_EXPR ( ( ( CConstFloatHandle )mulVal ).GetValueAt ( 0 ) > 0 );
51
-
52
- if ( function == -1 ) { // warm-up
53
- return ;
90
+ Function = function;
91
+ random = CRandom ( seed );
92
+ for ( int i = 0 ; i < VectorSize; ++i ) {
93
+ input[i] = static_cast <float >( random.Uniform ( valuesInterval.Begin , valuesInterval.End ) );
94
+ second[i] = static_cast <float >( random.Uniform ( valuesInterval.Begin , valuesInterval.End ) );
95
+ result[i] = 0 ;
96
+ }
97
+ }
98
+
99
+ // ------------------------------------------------------------------------------------------------------------
100
+
101
+ static double vectorBenchmark ( VectorBenchmarkParams& params )
102
+ {
103
+ CFloatWrapper zeroVal = params.GetInputBuffer ();
104
+ CFloatWrapper mulVal = params.GetMulVal ();
105
+ CConstFloatHandle mulHandle = mulVal;
106
+ ASSERT_EXPR ( mulHandle.GetValueAt ( 0 ) > 0 );
107
+
108
+ CFloatWrapper input = params.GetInputBuffer ();
109
+ CFloatWrapper second = params.GetSecondBuffer ();
110
+ CFloatWrapper result = params.GetResultBuffer ();
111
+ const int vectorSize = params.VectorSize ;
112
+
113
+ if ( params.Function == -1 ) { // warm-up
114
+ MathEngine ().VectorCopy ( result, input, vectorSize );
115
+ MathEngine ().VectorFill ( result, vectorSize, mulVal );
116
+ MathEngine ().VectorAdd ( input, second, result, vectorSize );
117
+ MathEngine ().VectorAddValue ( input, result, vectorSize, mulVal );
118
+ MathEngine ().VectorMultiply ( input, second, vectorSize, mulVal );
119
+ MathEngine ().VectorEltwiseMultiply ( input, second, result, vectorSize );
120
+ MathEngine ().VectorEltwiseMultiplyAdd ( input, second, result, vectorSize );
121
+ MathEngine ().VectorReLU ( input, result, vectorSize, zeroVal ); // Threshold == 0
122
+ MathEngine ().VectorReLU ( input, result, vectorSize, mulVal ); // Threshold > 0
123
+ MathEngine ().VectorHSwish ( input, result, vectorSize );
124
+ return 0 ;
54
125
}
55
126
56
- counters.Synchronise ();
57
-
58
- for ( int i = 0 ; i < testCount; ++i ) {
59
- switch ( function ) {
60
- case 0 :
61
- MathEngine ().VectorCopy ( CARRAY_FLOAT_WRAPPER ( result ), CARRAY_FLOAT_WRAPPER ( input ), vectorSize );
62
- break ;
63
- case 1 :
64
- MathEngine ().VectorFill ( CARRAY_FLOAT_WRAPPER ( result ), vectorSize, mulVal );
65
- break ;
66
- case 2 :
67
- MathEngine ().VectorAdd ( CARRAY_FLOAT_WRAPPER ( input ), CARRAY_FLOAT_WRAPPER ( second ),
68
- CARRAY_FLOAT_WRAPPER ( result ), vectorSize );
69
- break ;
70
- case 3 :
71
- MathEngine ().VectorAddValue ( CARRAY_FLOAT_WRAPPER ( input ), CARRAY_FLOAT_WRAPPER ( result ),
72
- vectorSize, mulVal );
73
- break ;
74
- case 4 :
75
- MathEngine ().VectorMultiply ( CARRAY_FLOAT_WRAPPER ( input ),
76
- CARRAY_FLOAT_WRAPPER ( second ), vectorSize, mulVal );
77
- break ;
78
- case 5 :
79
- MathEngine ().VectorEltwiseMultiply ( CARRAY_FLOAT_WRAPPER ( input ),
80
- CARRAY_FLOAT_WRAPPER ( second ), CARRAY_FLOAT_WRAPPER ( result ), vectorSize );
81
- break ;
82
- case 6 :
83
- MathEngine ().VectorEltwiseMultiplyAdd ( CARRAY_FLOAT_WRAPPER ( input ),
84
- CARRAY_FLOAT_WRAPPER ( second ), CARRAY_FLOAT_WRAPPER ( result ), vectorSize );
85
- break ;
86
- case 7 :
87
- MathEngine ().VectorReLU ( CARRAY_FLOAT_WRAPPER ( input ), CARRAY_FLOAT_WRAPPER ( result ),
88
- vectorSize, zeroVal ); // Threshold == 0
89
- break ;
90
- case 8 :
91
- MathEngine ().VectorReLU ( CARRAY_FLOAT_WRAPPER ( input ), CARRAY_FLOAT_WRAPPER ( result ),
92
- vectorSize, mulVal ); // Threshold > 0
93
- break ;
94
- case 9 :
95
- MathEngine ().VectorHSwish ( CARRAY_FLOAT_WRAPPER ( input ), CARRAY_FLOAT_WRAPPER ( result ),
96
- vectorSize );
97
- break ;
127
+ params.Counters ->Synchronise ();
128
+
129
+ for ( int i = 0 ; i < params.TestCount ; ++i ) {
130
+ switch ( params.Function ) {
131
+ case 0 : MathEngine ().VectorCopy ( result, input, vectorSize ); break ;
132
+ case 1 : MathEngine ().VectorFill ( result, vectorSize, mulVal ); break ;
133
+ case 2 : MathEngine ().VectorAdd ( input, second, result, vectorSize ); break ;
134
+ case 3 : MathEngine ().VectorAddValue ( input, result, vectorSize, mulVal ); break ;
135
+ case 4 : MathEngine ().VectorMultiply ( input, second, vectorSize, mulVal ); break ;
136
+ case 5 : MathEngine ().VectorEltwiseMultiply ( input, second, result, vectorSize ); break ;
137
+ case 6 : MathEngine ().VectorEltwiseMultiplyAdd ( input, second, result, vectorSize ); break ;
138
+ case 7 : MathEngine ().VectorReLU ( input, result, vectorSize, zeroVal ); break ; // Threshold == 0
139
+ case 8 : MathEngine ().VectorReLU ( input, result, vectorSize, mulVal ); break ; // Threshold > 0
140
+ case 9 : MathEngine ().VectorHSwish ( input, result, vectorSize ); break ;
98
141
default :
99
142
ASSERT_EXPR ( false );
100
143
}
101
144
}
102
145
103
- counters.Synchronise ();
104
- const double time = double ( counters[0 ].Value ) / 1000000 / testCount; // average time in milliseconds
105
-
106
- GTEST_LOG_ ( INFO ) << vectorFunctionsNames[function] << " , " << time;
107
- fout << vectorFunctionsNames[function] << " ," << time << " \n " ;
146
+ params.Counters ->Synchronise ();
147
+ const double time = double ( ( *params.Counters )[0 ].Value ) / 1000000 / params.TestCount ; // average time in milliseconds
148
+ params.FOut << time << " ," ;
149
+ return time;
108
150
}
109
151
110
152
} // namespace NeoMLTest
@@ -129,34 +171,34 @@ INSTANTIATE_TEST_CASE_P( CMathEngineVectorBenchmarkTestInstantiation, CMathEngin
129
171
" VectorValues = (-10..10);"
130
172
),
131
173
CTestParams(
132
- " TestCount = 100000 ;"
174
+ " TestCount = 1000 ;"
133
175
" RepeatCount = 10;"
134
- " VectorSize = 11796480 ;"
176
+ " VectorSize = 1179648 ;"
135
177
" VectorValues = (-1..1);"
136
178
)
137
179
)
138
180
);
139
181
140
182
TEST_P ( CMathEngineVectorBenchmarkTest, DISABLED_Random )
141
183
{
142
- CTestParams params = GetParam ();
143
-
144
- const int testCount = params.GetValue <int >( " TestCount" );
145
- const int repeatCount = params.GetValue <int >( " RepeatCount" );
146
- const int vectorSize = params.GetValue <int >( " VectorSize" );
147
- const CInterval vectorValuesInterval = params.GetInterval ( " VectorValues" );
184
+ CTestParams testParams = GetParam ();
185
+ const int testCount = testParams.GetValue <int >( " TestCount" );
186
+ const int repeatCount = testParams.GetValue <int >( " RepeatCount" );
187
+ const int vectorSize = testParams.GetValue <int >( " VectorSize" );
188
+ const CInterval valuesInterval = testParams.GetInterval ( " VectorValues" );
148
189
149
- IPerformanceCounters* counters = MathEngine ().CreatePerformanceCounters ();
150
- std::ofstream fout ( " VectorBenchmarkTest.csv" , std::ios::app );
151
- fout << " ---------------------------\n " ;
152
-
153
- vectorBenchmark ( /* warm-up*/ -1 , testCount, vectorSize, vectorValuesInterval, 282 , *counters, fout);
190
+ VectorBenchmarkParams params ( /* warm-up*/ -1 , testCount, vectorSize, valuesInterval, 282 );
191
+ vectorBenchmark ( params );
154
192
155
193
for ( int function = 0 ; function < 10 ; ++function ) {
194
+ params.FOut << std::endl << vectorFunctionsNames[function] << " ," ;
195
+
196
+ double timeSum = 0 ;
156
197
for ( int test = 0 ; test < repeatCount; ++test ) {
157
198
const int seed = 282 + test * 10000 + test % 3 ;
158
- vectorBenchmark ( function, testCount, vectorSize, vectorValuesInterval, seed, *counters, fout );
199
+ params.SetNextSeedForFunction ( function, seed );
200
+ timeSum += vectorBenchmark ( params );
159
201
}
202
+ GTEST_LOG_ ( INFO ) << vectorFunctionsNames[function] << " \t " << timeSum;
160
203
}
161
- delete counters;
162
204
}
0 commit comments