@@ -80,31 +80,26 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
80
80
void VectorEqual ( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
81
81
const CFloatHandle& resultHandle, int vectorSize ) override ;
82
82
void VectorEqualValue ( const CConstIntHandle& firstHandle,
83
- const CFloatHandle& resultHandle, int vectorSize, const CConstIntHandle& valueHandle ) override ;
84
- void VectorMax ( const CConstFloatHandle& firstHandle, float secondValue, const CFloatHandle& resultHandle,
83
+ const CFloatHandle& resultHandle, int vectorSize, CIntParam value ) override ;
84
+ void VectorMax ( const CConstFloatHandle& firstHandle, CFloatParam secondValue, const CFloatHandle& resultHandle,
85
85
int vectorSize ) override ;
86
- void VectorMaxDiff ( const CConstFloatHandle& firstHandle, float secondValue, const CFloatHandle& gradHandle,
86
+ void VectorMaxDiff ( const CConstFloatHandle& firstHandle, CFloatParam secondValue, const CFloatHandle& gradHandle,
87
87
int gradHeight, int gradWidth ) override ;
88
88
void VectorELU ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
89
- int vectorSize, const CConstFloatHandle& alpha ) override ;
89
+ int vectorSize, CFloatParam alpha ) override ;
90
90
void VectorELUDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
91
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& alpha ) override ;
91
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam alpha ) override ;
92
92
void VectorELUDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
93
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& alpha ) override ;
93
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam alpha ) override ;
94
94
void VectorReLU ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
95
- const CConstFloatHandle& upperThresholdHandle ) override ;
95
+ CFloatParam upperThreshold ) override ;
96
96
void VectorReLUDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
97
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle ) override ;
98
- void VectorReLUDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
99
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle ) override ;
97
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam upperThreshold ) override ;
100
98
void VectorLeakyReLU ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
101
- int vectorSize, const CConstFloatHandle& alpha ) override ;
99
+ int vectorSize, CFloatParam alpha ) override ;
102
100
void VectorLeakyReLUDiff ( const CConstFloatHandle& firstHandle,
103
101
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
104
- int vectorSize, const CConstFloatHandle& alpha ) override ;
105
- void VectorLeakyReLUDiffOp ( const CConstFloatHandle& firstHandle,
106
- const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
107
- int vectorSize, const CConstFloatHandle& alpha ) override ;
102
+ int vectorSize, CFloatParam alpha ) override ;
108
103
void VectorHSwish ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
109
104
int vectorSize ) override ;
110
105
void VectorHSwishDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
@@ -129,16 +124,12 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
129
124
void VectorHardTanh ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
130
125
void VectorHardTanhDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
131
126
const CFloatHandle& resultHandle, int vectorSize ) override ;
132
- void VectorHardTanhDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
133
- const CFloatHandle& resultHandle, int vectorSize ) override ;
134
127
void VectorHardSigmoid ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
135
- const CConstFloatHandle& slopeHandle, const CConstFloatHandle& biasHandle ) override ;
128
+ CFloatParam slope, CFloatParam bias ) override ;
136
129
void VectorHardSigmoidDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
137
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& slopeHandle,
138
- const CConstFloatHandle& biasHandle ) override ;
130
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam slope, CFloatParam bias ) override ;
139
131
void VectorHardSigmoidDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
140
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& slopeHandle,
141
- const CConstFloatHandle& biasHandle ) override ;
132
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam slope, CFloatParam bias ) override ;
142
133
void VectorNeg ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
143
134
void VectorExp ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
144
135
void VectorLog ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
@@ -148,15 +139,15 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
148
139
void VectorNegLog ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
149
140
void VectorErf ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
150
141
void VectorBernulliKLDerivative ( const CConstFloatHandle& estimationHandle,
151
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& target ) override ;
142
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam target ) override ;
152
143
void VectorAdd ( const CConstFloatHandle& firstHandle,
153
144
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
154
145
void VectorAdd ( const CConstIntHandle& firstHandle,
155
146
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override ;
156
147
void VectorAddValue ( const CConstFloatHandle& firstHandle,
157
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& addition ) override ;
148
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam value ) override ;
158
149
void VectorAddValue ( const CConstIntHandle& firstHandle,
159
- const CIntHandle& resultHandle, int vectorSize, const CConstIntHandle& addition ) override ;
150
+ const CIntHandle& resultHandle, int vectorSize, CIntParam value ) override ;
160
151
void VectorSub ( const CConstIntHandle& firstHandle,
161
152
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override ;
162
153
void VectorSub ( const CConstFloatHandle& firstHandle,
@@ -166,15 +157,15 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
166
157
void VectorSub ( float first,
167
158
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
168
159
void VectorMultiplyAndAdd ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
169
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle ) override ;
160
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override ;
170
161
void VectorMultiplyAndSub ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
171
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle ) override ;
162
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override ;
172
163
void VectorMultiply ( const CConstFloatHandle& firstHandle,
173
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multiplierHandle ) override ;
164
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override ;
174
165
void VectorMultiply ( const CConstIntHandle& firstHandle,
175
- const CIntHandle& resultHandle, int vectorSize, const CConstIntHandle& multiplierHandle ) override ;
166
+ const CIntHandle& resultHandle, int vectorSize, CIntParam mult ) override ;
176
167
void VectorNegMultiply ( const CConstFloatHandle& firstHandle,
177
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multiplierHandle ) override ;
168
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override ;
178
169
void VectorEltwiseMultiply ( const CConstIntHandle& firstHandle,
179
170
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override ;
180
171
void VectorEltwiseMultiply ( const CConstFloatHandle& firstHandle,
@@ -192,10 +183,10 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
192
183
void VectorSqrt ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
193
184
void VectorInv ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
194
185
void VectorMinMax ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
195
- const CConstFloatHandle& minHandle, const CConstFloatHandle& maxHandle ) override ;
186
+ CFloatParam min, CFloatParam max ) override ;
196
187
void VectorMinMaxDiff ( const CConstFloatHandle& sourceGradHandle, int gradHeight, int gradWidth,
197
188
const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
198
- const CConstFloatHandle& minHandle, const CConstFloatHandle& maxHandle ) override ;
189
+ CFloatParam min, CFloatParam max ) override ;
199
190
void VectorSigmoid ( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override ;
200
191
void VectorSigmoidDiff ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
201
192
const CFloatHandle& resultHandle, int vectorSize ) override ;
@@ -213,8 +204,7 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
213
204
void VectorPowerDiffOp ( float exponent, const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
214
205
const CFloatHandle& resultHandle, int vectorSize ) override ;
215
206
void VectorL1DiffAdd ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
216
- const CFloatHandle& resultHandle, int vectorSize,
217
- const CConstFloatHandle& hubertThresholdHandle, const CConstFloatHandle& multHandle ) override ;
207
+ const CFloatHandle& resultHandle, int vectorSize, CFloatParam hubertThreshold, CFloatParam mult ) override ;
218
208
void VectorDotProduct ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle, int vectorSize,
219
209
const CFloatHandle& resultHandle ) override ;
220
210
void VectorEltwiseNot ( const CConstIntHandle& firstHandle, const CIntHandle& resultHandle, int vectorSize ) override ;
@@ -310,10 +300,10 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
310
300
const CIntHandle& outputHandle, int outputChannels ) override ;
311
301
void VectorMultichannelLookupAndAddToTable ( int batchSize, int channelCount, const CConstFloatHandle& inputHandle,
312
302
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount,
313
- const CConstFloatHandle& multHandle , const CConstFloatHandle& matrixHandle, int outputChannels ) override ;
303
+ CFloatParam mult , const CConstFloatHandle& matrixHandle, int outputChannels ) override ;
314
304
void VectorMultichannelLookupAndAddToTable ( int batchSize, int channelCount, const CConstIntHandle& inputHandle,
315
305
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount,
316
- const CConstFloatHandle& multHandle , const CConstFloatHandle& matrixHandle, int outputChannels ) override ;
306
+ CFloatParam mult , const CConstFloatHandle& matrixHandle, int outputChannels ) override ;
317
307
void LookupAndSum ( const CConstIntHandle& indicesHandle, int batchSize, int indexCount,
318
308
const CConstFloatHandle& tableHandle, int vectorSize, const CFloatHandle& result ) override ;
319
309
void LookupAndAddToTable ( const CConstIntHandle& indicesHandle, int batchSize, int indexCount,
@@ -616,8 +606,8 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
616
606
617
607
IPerformanceCounters* CreatePerformanceCounters ( bool ) const override { return new CPerformanceCountersDefault (); }
618
608
// For Distributed only
619
- void AllReduce ( const CFloatHandle& /* handle*/ , int /* size*/ ) override {};
620
- void Broadcast ( const CFloatHandle& /* handle*/ , int /* size*/ , int /* root*/ ) override {};
609
+ void AllReduce ( const CFloatHandle& /* handle*/ , int /* size*/ ) override {}
610
+ void Broadcast ( const CFloatHandle& /* handle*/ , int /* size*/ , int /* root*/ ) override {}
621
611
622
612
protected:
623
613
// IRawMemoryManager interface methods
@@ -640,26 +630,8 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
640
630
const CBlobDesc& to, const CFloatHandle& toData );
641
631
void blobSplitByDim ( int dimNum, const CBlobDesc& from, const CConstFloatHandle& fromData,
642
632
const CBlobDesc* to, const CFloatHandle* toData, int toCount );
643
- };
644
633
645
- inline void CMetalMathEngine::VectorReLUDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
646
- const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle )
647
- {
648
- VectorReLUDiff ( firstHandle, secondHandle, resultHandle, vectorSize, upperThresholdHandle );
649
- }
650
-
651
- inline void CMetalMathEngine::VectorLeakyReLUDiffOp ( const CConstFloatHandle& firstHandle,
652
- const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
653
- int vectorSize, const CConstFloatHandle& alpha )
654
- {
655
- VectorLeakyReLUDiff ( firstHandle, secondHandle, resultHandle, vectorSize, alpha );
656
- }
657
-
658
- inline void CMetalMathEngine::VectorHardTanhDiffOp ( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
659
- const CFloatHandle& resultHandle, int vectorSize )
660
- {
661
- VectorHardTanhDiff ( firstHandle, secondHandle, resultHandle, vectorSize );
662
- }
634
+ };
663
635
664
636
} // namespace NeoML
665
637
0 commit comments