Skip to content

Commit 41e9ca4

Browse files
committed
[MetalMathEngine] Add CScalarParameter
Signed-off-by: Kirill Golikov <[email protected]>
1 parent a1d20fb commit 41e9ca4

File tree

4 files changed

+180
-155
lines changed

4 files changed

+180
-155
lines changed

Build/build.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ CMAKE_WORKING_DIR=$ROOT/_cmake_working_dir/NeoML.${FINE_CMAKE_BUILD_TARGET}.${FI
99
pushd ${CMAKE_WORKING_DIR}
1010

1111
if [[ $FINE_CMAKE_BUILD_TARGET == "IOS" ]]; then
12-
cmake -G Xcode -DUSE_FINE_OBJECTS=ON -DCMAKE_TOOLCHAIN_FILE=${ROOT}/NeoML/cmake/ios.toolchain.cmake -DIOS_ARCH=${FINE_CMAKE_BUILD_ARCH} ${ROOT}/NeoML/NeoML
12+
cmake -G Xcode -DUSE_FINE_OBJECTS=ON -DCMAKE_TOOLCHAIN_FILE=${ROOT}/NeoML/cmake/ios.toolchain.cmake -DIOS_ARCH=${FINE_CMAKE_BUILD_ARCH} ${ROOT}/NeoML/NeoML -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_CONFIG}
1313
elif [[ $FINE_CMAKE_BUILD_TARGET == "Linux" && $FINE_CMAKE_BUILD_ARCH == "x86" ]]; then
1414
cmake -DUSE_FINE_OBJECTS=ON -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_CONFIG} -DCMAKE_CXX_FLAGS=-m32 -DCMAKE_C_FLAGS=-m32 ${ROOT}/NeoML/NeoML
1515
elif [[ $FINE_CMAKE_BUILD_TARGET == "Linux" ]]; then

NeoMathEngine/src/GPU/Metal/MetalMathEngine.h

Lines changed: 29 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -80,31 +80,26 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
8080
void VectorEqual( const CConstIntHandle& firstHandle, const CConstIntHandle& secondHandle,
8181
const CFloatHandle& resultHandle, int vectorSize ) override;
8282
void VectorEqualValue( const CConstIntHandle& firstHandle,
83-
const CFloatHandle& resultHandle, int vectorSize, const CConstIntHandle& valueHandle ) override;
84-
void VectorMax( const CConstFloatHandle& firstHandle, float secondValue, const CFloatHandle& resultHandle,
83+
const CFloatHandle& resultHandle, int vectorSize, CIntParam value ) override;
84+
void VectorMax( const CConstFloatHandle& firstHandle, CFloatParam secondValue, const CFloatHandle& resultHandle,
8585
int vectorSize ) override;
86-
void VectorMaxDiff( const CConstFloatHandle& firstHandle, float secondValue, const CFloatHandle& gradHandle,
86+
void VectorMaxDiff( const CConstFloatHandle& firstHandle, CFloatParam secondValue, const CFloatHandle& gradHandle,
8787
int gradHeight, int gradWidth ) override;
8888
void VectorELU( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
89-
int vectorSize, const CConstFloatHandle& alpha ) override;
89+
int vectorSize, CFloatParam alpha ) override;
9090
void VectorELUDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
91-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& alpha ) override;
91+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam alpha ) override;
9292
void VectorELUDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
93-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& alpha ) override;
93+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam alpha ) override;
9494
void VectorReLU( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
95-
const CConstFloatHandle& upperThresholdHandle ) override;
95+
CFloatParam upperThreshold ) override;
9696
void VectorReLUDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
97-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle ) override;
98-
void VectorReLUDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
99-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle ) override;
97+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam upperThreshold ) override;
10098
void VectorLeakyReLU( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
101-
int vectorSize, const CConstFloatHandle& alpha ) override;
99+
int vectorSize, CFloatParam alpha ) override;
102100
void VectorLeakyReLUDiff( const CConstFloatHandle& firstHandle,
103101
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
104-
int vectorSize, const CConstFloatHandle& alpha ) override;
105-
void VectorLeakyReLUDiffOp( const CConstFloatHandle& firstHandle,
106-
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
107-
int vectorSize, const CConstFloatHandle& alpha ) override;
102+
int vectorSize, CFloatParam alpha ) override;
108103
void VectorHSwish( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
109104
int vectorSize ) override;
110105
void VectorHSwishDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
@@ -129,16 +124,12 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
129124
void VectorHardTanh( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
130125
void VectorHardTanhDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
131126
const CFloatHandle& resultHandle, int vectorSize ) override;
132-
void VectorHardTanhDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
133-
const CFloatHandle& resultHandle, int vectorSize ) override;
134127
void VectorHardSigmoid( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
135-
const CConstFloatHandle& slopeHandle, const CConstFloatHandle& biasHandle ) override;
128+
CFloatParam slope, CFloatParam bias ) override;
136129
void VectorHardSigmoidDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
137-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& slopeHandle,
138-
const CConstFloatHandle& biasHandle ) override;
130+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam slope, CFloatParam bias ) override;
139131
void VectorHardSigmoidDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
140-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& slopeHandle,
141-
const CConstFloatHandle& biasHandle ) override;
132+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam slope, CFloatParam bias ) override;
142133
void VectorNeg( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
143134
void VectorExp( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
144135
void VectorLog( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
@@ -148,15 +139,15 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
148139
void VectorNegLog( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
149140
void VectorErf( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
150141
void VectorBernulliKLDerivative( const CConstFloatHandle& estimationHandle,
151-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& target ) override;
142+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam target ) override;
152143
void VectorAdd( const CConstFloatHandle& firstHandle,
153144
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
154145
void VectorAdd( const CConstIntHandle& firstHandle,
155146
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override;
156147
void VectorAddValue( const CConstFloatHandle& firstHandle,
157-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& addition ) override;
148+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam value ) override;
158149
void VectorAddValue( const CConstIntHandle& firstHandle,
159-
const CIntHandle& resultHandle, int vectorSize, const CConstIntHandle& addition ) override;
150+
const CIntHandle& resultHandle, int vectorSize, CIntParam value ) override;
160151
void VectorSub( const CConstIntHandle& firstHandle,
161152
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override;
162153
void VectorSub( const CConstFloatHandle& firstHandle,
@@ -166,15 +157,15 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
166157
void VectorSub( float first,
167158
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
168159
void VectorMultiplyAndAdd( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
169-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle ) override;
160+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override;
170161
void VectorMultiplyAndSub( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
171-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle ) override;
162+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override;
172163
void VectorMultiply( const CConstFloatHandle& firstHandle,
173-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multiplierHandle ) override;
164+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override;
174165
void VectorMultiply( const CConstIntHandle& firstHandle,
175-
const CIntHandle& resultHandle, int vectorSize, const CConstIntHandle& multiplierHandle ) override;
166+
const CIntHandle& resultHandle, int vectorSize, CIntParam mult ) override;
176167
void VectorNegMultiply( const CConstFloatHandle& firstHandle,
177-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multiplierHandle ) override;
168+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult ) override;
178169
void VectorEltwiseMultiply( const CConstIntHandle& firstHandle,
179170
const CConstIntHandle& secondHandle, const CIntHandle& resultHandle, int vectorSize ) override;
180171
void VectorEltwiseMultiply( const CConstFloatHandle& firstHandle,
@@ -192,10 +183,10 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
192183
void VectorSqrt( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
193184
void VectorInv( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
194185
void VectorMinMax( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize,
195-
const CConstFloatHandle& minHandle, const CConstFloatHandle& maxHandle ) override;
186+
CFloatParam min, CFloatParam max ) override;
196187
void VectorMinMaxDiff( const CConstFloatHandle& sourceGradHandle, int gradHeight, int gradWidth,
197188
const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle,
198-
const CConstFloatHandle& minHandle, const CConstFloatHandle& maxHandle ) override;
189+
CFloatParam min, CFloatParam max ) override;
199190
void VectorSigmoid( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) override;
200191
void VectorSigmoidDiff( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
201192
const CFloatHandle& resultHandle, int vectorSize ) override;
@@ -213,8 +204,7 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
213204
void VectorPowerDiffOp( float exponent, const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
214205
const CFloatHandle& resultHandle, int vectorSize ) override;
215206
void VectorL1DiffAdd( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
216-
const CFloatHandle& resultHandle, int vectorSize,
217-
const CConstFloatHandle& hubertThresholdHandle, const CConstFloatHandle& multHandle ) override;
207+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam hubertThreshold, CFloatParam mult ) override;
218208
void VectorDotProduct( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle, int vectorSize,
219209
const CFloatHandle& resultHandle ) override;
220210
void VectorEltwiseNot( const CConstIntHandle& firstHandle, const CIntHandle& resultHandle, int vectorSize ) override;
@@ -310,10 +300,10 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
310300
const CIntHandle& outputHandle, int outputChannels ) override;
311301
void VectorMultichannelLookupAndAddToTable( int batchSize, int channelCount, const CConstFloatHandle& inputHandle,
312302
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount,
313-
const CConstFloatHandle& multHandle, const CConstFloatHandle& matrixHandle, int outputChannels ) override;
303+
CFloatParam mult, const CConstFloatHandle& matrixHandle, int outputChannels ) override;
314304
void VectorMultichannelLookupAndAddToTable( int batchSize, int channelCount, const CConstIntHandle& inputHandle,
315305
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount,
316-
const CConstFloatHandle& multHandle, const CConstFloatHandle& matrixHandle, int outputChannels ) override;
306+
CFloatParam mult, const CConstFloatHandle& matrixHandle, int outputChannels ) override;
317307
void LookupAndSum( const CConstIntHandle& indicesHandle, int batchSize, int indexCount,
318308
const CConstFloatHandle& tableHandle, int vectorSize, const CFloatHandle& result ) override;
319309
void LookupAndAddToTable( const CConstIntHandle& indicesHandle, int batchSize, int indexCount,
@@ -616,8 +606,8 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
616606

617607
IPerformanceCounters* CreatePerformanceCounters( bool ) const override { return new CPerformanceCountersDefault(); }
618608
// For Distributed only
619-
void AllReduce( const CFloatHandle& /*handle*/, int /*size*/ ) override {};
620-
void Broadcast( const CFloatHandle& /*handle*/, int /*size*/, int /*root*/ ) override {};
609+
void AllReduce( const CFloatHandle& /*handle*/, int /*size*/ ) override {}
610+
void Broadcast( const CFloatHandle& /*handle*/, int /*size*/, int /*root*/ ) override {}
621611

622612
protected:
623613
// IRawMemoryManager interface methods
@@ -640,26 +630,8 @@ class CMetalMathEngine : public CMemoryEngineMixin, public IRawMemoryManager {
640630
const CBlobDesc& to, const CFloatHandle& toData );
641631
void blobSplitByDim( int dimNum, const CBlobDesc& from, const CConstFloatHandle& fromData,
642632
const CBlobDesc* to, const CFloatHandle* toData, int toCount );
643-
};
644633

645-
inline void CMetalMathEngine::VectorReLUDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
646-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& upperThresholdHandle )
647-
{
648-
VectorReLUDiff( firstHandle, secondHandle, resultHandle, vectorSize, upperThresholdHandle );
649-
}
650-
651-
inline void CMetalMathEngine::VectorLeakyReLUDiffOp( const CConstFloatHandle& firstHandle,
652-
const CConstFloatHandle& secondHandle, const CFloatHandle& resultHandle,
653-
int vectorSize, const CConstFloatHandle& alpha )
654-
{
655-
VectorLeakyReLUDiff( firstHandle, secondHandle, resultHandle, vectorSize, alpha );
656-
}
657-
658-
inline void CMetalMathEngine::VectorHardTanhDiffOp( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
659-
const CFloatHandle& resultHandle, int vectorSize )
660-
{
661-
VectorHardTanhDiff( firstHandle, secondHandle, resultHandle, vectorSize );
662-
}
634+
};
663635

664636
} // namespace NeoML
665637

NeoMathEngine/src/GPU/Metal/MetalMathEngineBlas.mm

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
@import MetalKit;
2727

2828
namespace NeoML {
29-
29+
3030
// The number of combined values for the vector kernels
3131
static const int VectorCombineCount = 8;
3232

@@ -193,13 +193,15 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelCopyIntIndicesIntData",
193193
}
194194

195195
void CMetalMathEngine::VectorMultichannelLookupAndAddToTable( int batchSize, int channelCount, const CConstFloatHandle& inputHandle,
196-
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, const CConstFloatHandle& multHandle,
196+
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, CFloatParam mult,
197197
const CConstFloatHandle& matrixHandle, int outputChannelsCount )
198198
{
199199
ASSERT_EXPR( inputHandle.GetMathEngine() == this );
200-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
201200
ASSERT_EXPR( matrixHandle.GetMathEngine() == this );
202201

202+
CFloatHandleStackVar multHandle( *this );
203+
multHandle.SetValue( mult );
204+
203205
int outputChannel = 0;
204206
for( int i = 0; i < lookupCount; ++i ) {
205207
C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableFloat",
@@ -222,13 +224,15 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableFloa
222224
}
223225

224226
void CMetalMathEngine::VectorMultichannelLookupAndAddToTable( int batchSize, int channelCount, const CConstIntHandle& inputHandle,
225-
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, const CConstFloatHandle& multHandle,
227+
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, CFloatParam mult,
226228
const CConstFloatHandle& matrixHandle, int outputChannelsCount )
227229
{
228230
ASSERT_EXPR( inputHandle.GetMathEngine() == this );
229-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
230231
ASSERT_EXPR( matrixHandle.GetMathEngine() == this );
231232

233+
CFloatHandleStackVar multHandle( *this );
234+
multHandle.SetValue( mult );
235+
232236
int outputChannel = 0;
233237
for( int i = 0; i < lookupCount; ++i ) {
234238
C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableInt",
@@ -366,12 +370,14 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableInt"
366370
}
367371

368372
void CMetalMathEngine::VectorMultiplyAndAdd( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
369-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle )
373+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult )
370374
{
371375
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
372376
ASSERT_EXPR( secondHandle.GetMathEngine() == this );
373377
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
374-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
378+
379+
CFloatHandleStackVar multHandle( *this );
380+
multHandle.SetValue( mult );
375381

376382
C1DKernel kernel( *queue, "vectorKernelVectorMultiplyAndAdd", 1, vectorSize );
377383
kernel.SetParam( firstHandle, 0 );
@@ -383,12 +389,14 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableInt"
383389
}
384390

385391
void CMetalMathEngine::VectorMultiplyAndSub( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
386-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle )
392+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult )
387393
{
388394
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
389395
ASSERT_EXPR( secondHandle.GetMathEngine() == this );
390396
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
391-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
397+
398+
CFloatHandleStackVar multHandle( *this );
399+
multHandle.SetValue( mult );
392400

393401
C1DKernel kernel( *queue, "vectorKernelVectorMultiplyAndSub", 1, vectorSize );
394402
kernel.SetParam( firstHandle, 0 );

0 commit comments

Comments
 (0)