Skip to content

Commit bf1258f

Browse files
committed
[MetalMathEngine] Fix compilation
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 875d164 commit bf1258f

File tree

7 files changed

+408
-456
lines changed

7 files changed

+408
-456
lines changed

NeoMathEngine/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ option(NeoMathEngine_BUILD_SHARED "Build NeoMathEngine as shared library." ON)
3232
option(NeoMathEngine_ENABLE_VULKAN "Enable to build vulkan backend" ON)
3333

3434
# Enable to build metal backend
35-
option(NeoMathEngine_ENABLE_METAL "Enable to build metal backend" OFF)
35+
option(NeoMathEngine_ENABLE_METAL "Enable to build metal backend" ON)
3636

3737
# Install NeoMathEngine
3838
if(DEFINED NeoML_INSTALL)

NeoMathEngine/src/GPU/Metal/MetalMathEngine.h

Lines changed: 262 additions & 306 deletions
Large diffs are not rendered by default.

NeoMathEngine/src/GPU/Metal/MetalMathEngine.mm

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,12 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
6464

6565
//----------------------------------------------------------------------------------------------------------------------------
6666

67-
// Not using STL in headers
68-
class CMutex : public std::mutex {
69-
};
70-
71-
//----------------------------------------------------------------------------------------------------------------------------
72-
7367
const int MetalMemoryAlignment = 16;
7468

7569
CMetalMathEngine::CMetalMathEngine( size_t memoryLimit ) :
7670
queue( new CMetalCommandQueue() ),
7771
memoryPool( new CMemoryPool( MIN( memoryLimit == 0 ? SIZE_MAX : memoryLimit, defineMemoryLimit() ), this, false ) ),
78-
deviceStackAllocator( new CDeviceStackAllocator( *memoryPool, MetalMemoryAlignment ) ),
79-
mutex( new CMutex() )
72+
deviceStackAllocator( new CDeviceStackAllocator( *memoryPool, MetalMemoryAlignment ) )
8073
{
8174
ASSERT_EXPR( queue->Create() );
8275
}
@@ -87,31 +80,31 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
8780

8881
void CMetalMathEngine::SetReuseMemoryMode( bool enable )
8982
{
90-
std::lock_guard<CMutex> lock( *mutex );
83+
std::lock_guard<std::mutex> lock( mutex );
9184
memoryPool->SetReuseMemoryMode( enable );
9285
}
9386

9487
bool CMetalMathEngine::GetReuseMemoryMode() const
9588
{
96-
std::lock_guard<CMutex> lock( *mutex );
89+
std::lock_guard<std::mutex> lock( mutex );
9790
return memoryPool->GetReuseMemoryMode();
9891
}
9992

10093
void CMetalMathEngine::SetThreadBufferMemoryThreshold( size_t threshold )
10194
{
102-
std::lock_guard<CMutex> lock( *mutex );
95+
std::lock_guard<std::mutex> lock( mutex );
10396
memoryPool->SetThreadBufferMemoryThreshold( threshold );
10497
}
10598

10699
size_t CMetalMathEngine::GetThreadBufferMemoryThreshold() const
107100
{
108-
std::lock_guard<CMutex> lock( *mutex );
101+
std::lock_guard<std::mutex> lock( mutex );
109102
return memoryPool->GetThreadBufferMemoryThreshold();
110103
}
111104

112105
CMemoryHandle CMetalMathEngine::HeapAlloc( size_t size )
113106
{
114-
std::lock_guard<CMutex> lock( *mutex );
107+
std::lock_guard<std::mutex> lock( mutex );
115108
CMemoryHandle result = memoryPool->Alloc( size );
116109
if( result.IsNull() ) {
117110
THROW_MEMORY_EXCEPTION;
@@ -124,13 +117,13 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
124117
{
125118
ASSERT_EXPR( handle.GetMathEngine() == this );
126119

127-
std::lock_guard<CMutex> lock( *mutex );
120+
std::lock_guard<std::mutex> lock( mutex );
128121
return memoryPool->Free( handle );
129122
}
130123

131124
CMemoryHandle CMetalMathEngine::StackAlloc( size_t size )
132125
{
133-
std::lock_guard<CMutex> lock( *mutex );
126+
std::lock_guard<std::mutex> lock( mutex );
134127
CMemoryHandle result = deviceStackAllocator->Alloc( size );
135128
if( result.IsNull() ) {
136129
THROW_MEMORY_EXCEPTION;
@@ -140,43 +133,43 @@ bool LoadMetalEngineInfo( CMathEngineInfo& info )
140133

141134
void CMetalMathEngine::StackFree( const CMemoryHandle& ptr )
142135
{
143-
std::lock_guard<CMutex> lock( *mutex );
136+
std::lock_guard<std::mutex> lock( mutex );
144137
deviceStackAllocator->Free( ptr );
145138
}
146139

147140
size_t CMetalMathEngine::GetFreeMemorySize() const
148141
{
149-
std::lock_guard<CMutex> lock( *mutex );
142+
std::lock_guard<std::mutex> lock( mutex );
150143
return memoryPool->GetFreeMemorySize();
151144
}
152145

153146
size_t CMetalMathEngine::GetPeakMemoryUsage() const
154147
{
155-
std::lock_guard<CMutex> lock( *mutex );
148+
std::lock_guard<std::mutex> lock( mutex );
156149
return memoryPool->GetPeakMemoryUsage();
157150
}
158151

159152
void CMetalMathEngine::ResetPeakMemoryUsage()
160153
{
161-
std::lock_guard<CMutex> lock( *mutex );
154+
std::lock_guard<std::mutex> lock( mutex );
162155
memoryPool->ResetPeakMemoryUsage();
163156
}
164157

165158
size_t CMetalMathEngine::GetCurrentMemoryUsage() const
166159
{
167-
std::lock_guard<CMutex> lock( *mutex );
160+
std::lock_guard<std::mutex> lock( mutex );
168161
return memoryPool->GetCurrentMemoryUsage();
169162
}
170163

171164
size_t CMetalMathEngine::GetMemoryInPools() const
172165
{
173-
std::lock_guard<CMutex> lock( *mutex );
166+
std::lock_guard<std::mutex> lock( mutex );
174167
return memoryPool->GetMemoryInPools();
175168
}
176169

177170
void CMetalMathEngine::CleanUp()
178171
{
179-
std::lock_guard<CMutex> lock( *mutex );
172+
std::lock_guard<std::mutex> lock( mutex );
180173
deviceStackAllocator->CleanUp();
181174
memoryPool->CleanUp();
182175
}

NeoMathEngine/src/GPU/Metal/MetalMathEngineBlas.mm

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -198,13 +198,15 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelCopyFloatIndicesFloatDa
198198
}
199199

200200
void CMetalMathEngine::VectorMultichannelLookupAndAddToTable(int batchSize, int channelCount, const CConstFloatHandle& inputHandle,
201-
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, const CConstFloatHandle& multHandle,
201+
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, CFloatParam mult,
202202
const CConstFloatHandle& matrixHandle, int outputChannelsCount)
203203
{
204204
ASSERT_EXPR( inputHandle.GetMathEngine() == this );
205-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
206205
ASSERT_EXPR( matrixHandle.GetMathEngine() == this );
207206

207+
CFloatHandleStackVar multHandle( *this );
208+
multHandle.SetValue( mult );
209+
208210
int outputChannel = 0;
209211
for( int i = 0; i < lookupCount; ++i ) {
210212
C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableFloat", 1, 1, lookupDimensions[i].VectorCount, lookupDimensions[i].VectorSize );
@@ -226,13 +228,15 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelCopyFloatIndicesFloatDa
226228
}
227229

228230
void CMetalMathEngine::VectorMultichannelLookupAndAddToTable(int batchSize, int channelCount, const CConstIntHandle& inputHandle,
229-
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, const CConstFloatHandle& multHandle,
231+
const CFloatHandle* lookupHandles, const CLookupDimension* lookupDimensions, int lookupCount, CFloatParam mult,
230232
const CConstFloatHandle& matrixHandle, int outputChannelsCount)
231233
{
232234
ASSERT_EXPR( inputHandle.GetMathEngine() == this );
233-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
234235
ASSERT_EXPR( matrixHandle.GetMathEngine() == this );
235236

237+
CFloatHandleStackVar multHandle( *this );
238+
multHandle.SetValue( mult );
239+
236240
int outputChannel = 0;
237241
for( int i = 0; i < lookupCount; ++i ) {
238242
C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelLookupAndAddToTableInt", 1, 1, lookupDimensions[i].VectorCount, lookupDimensions[i].VectorSize );
@@ -369,12 +373,14 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelCopyFloatIndicesFloatDa
369373
}
370374

371375
void CMetalMathEngine::VectorMultiplyAndAdd(const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
372-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle)
376+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult)
373377
{
374378
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
375379
ASSERT_EXPR( secondHandle.GetMathEngine() == this );
376380
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
377-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
381+
382+
CFloatHandleStackVar multHandle( *this );
383+
multHandle.SetValue( mult );
378384

379385
C1DKernel kernel( *queue, "vectorKernelVectorMultiplyAndAdd", 1, vectorSize );
380386
kernel.SetParam( firstHandle, 0 );
@@ -386,12 +392,14 @@ C2DKernel kernel( *queue, "matrixKernelBatchVectorChannelCopyFloatIndicesFloatDa
386392
}
387393

388394
void CMetalMathEngine::VectorMultiplyAndSub(const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
389-
const CFloatHandle& resultHandle, int vectorSize, const CConstFloatHandle& multHandle)
395+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult)
390396
{
391397
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
392398
ASSERT_EXPR( secondHandle.GetMathEngine() == this );
393399
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
394-
ASSERT_EXPR( multHandle.GetMathEngine() == this );
400+
401+
CFloatHandleStackVar multHandle( *this );
402+
multHandle.SetValue( mult );
395403

396404
C1DKernel kernel( *queue, "vectorKernelVectorMultiplyAndSub", 1, vectorSize );
397405
kernel.SetParam( firstHandle, 0 );
@@ -950,7 +958,7 @@ C2DKernel kernel( *queue, "matrixKernelMultiplyMatrixByTransposedMatrixThread4x4
950958
}
951959

952960
void CMetalMathEngine::BatchMultiplyMatrixByDiagMatrix( int batchSize, const CConstFloatHandle& firstHandle, int height,
953-
int width, int firstMatrixOffset, const CConstFloatHandle& secondHandle, int secondMatrixOffset,
961+
int width, int /*firstMatrixOffset*/, const CConstFloatHandle& secondHandle, int /*secondMatrixOffset*/,
954962
const CFloatHandle& resultHandle, int )
955963
{
956964
ASSERT_EXPR( firstHandle.GetMathEngine() == this );

NeoMathEngine/src/GPU/Metal/MetalMathEngineDnn.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -690,7 +690,7 @@ C2DKernel kernel( *queue,
690690

691691
void CMetalMathEngine::MobileNetV2Block( const CBlobDesc&, const CBlobDesc&,
692692
const CRowwiseOperationDesc&, const CChannelwiseConvolutionDesc&,
693-
const CConstFloatHandle&, const CFloatHandle& ) override
693+
const CConstFloatHandle&, const CFloatHandle& )
694694
{
695695
ASSERT_EXPR( false );
696696
}

NeoMathEngine/src/GPU/Metal/MetalMathEngineDnnPoolings.mm

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2023 ABBYY
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -39,7 +39,7 @@
3939
}
4040

4141
void CMetalMathEngine::BlobMaxPooling( const CMaxPoolingDesc& poolingDesc,
42-
const CFConstloatHandle& sourceData, const CIntHandle* maxIndicesData, const CFloatHandle& resultData )
42+
const CConstFloatHandle& sourceData, const CIntHandle* maxIndicesData, const CFloatHandle& resultData )
4343
{
4444
ASSERT_EXPR( sourceData.GetMathEngine() == this );
4545
ASSERT_EXPR( maxIndicesData == 0 );
@@ -219,7 +219,7 @@
219219
}
220220

221221
void CMetalMathEngine::Blob3dMaxPooling( const C3dMaxPoolingDesc& poolingDesc, const CConstFloatHandle& sourceData,
222-
const CConstIntHandle* maxIndicesData, const CFloatHandle& resultData )
222+
const CIntHandle* maxIndicesData, const CFloatHandle& resultData )
223223
{
224224
ASSERT_EXPR( sourceData.GetMathEngine() == this );
225225
ASSERT_EXPR( maxIndicesData == 0 || maxIndicesData->GetMathEngine() == this );

0 commit comments

Comments
 (0)