Skip to content

Commit 1ff6f53

Browse files
committed
[NeoML] CDnnBlob less memory consumption
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 90d43bc commit 1ff6f53

File tree

5 files changed

+130
-82
lines changed

5 files changed

+130
-82
lines changed

NeoML/Python/src/PyCustomLossLayer.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2021 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -60,16 +60,14 @@ static CArchive& operator >>( CArchive& archive, py::object& obj )
6060

6161
//------------------------------------------------------------------------------------------------------------
6262

63-
class CTempBlob : public CDnnBlob {
63+
// CPyTempDnnBlob does not own the data behind its handle
64+
class CPyTempDnnBlob : public CDnnBlobView {
6465
public:
65-
CTempBlob( IMathEngine& mathEngine, const CConstFloatHandle& data, const CBlobDesc& dataDesc );
66+
CPyTempDnnBlob( IMathEngine& mathEngine, const CConstFloatHandle& data, const CBlobDesc& dataDesc ) :
67+
CDnnBlobView( mathEngine, dataDesc, data )
68+
{}
6669
};
6770

68-
CTempBlob::CTempBlob( IMathEngine& mathEngine, const CConstFloatHandle& data, const CBlobDesc& dataDesc ) :
69-
CDnnBlob( mathEngine, dataDesc, data, false )
70-
{
71-
}
72-
7371
//------------------------------------------------------------------------------------------------------------
7472

7573
class CPythonLossLayer : public CLossLayer {
@@ -89,11 +87,11 @@ class CPythonLossLayer : public CLossLayer {
8987

9088
CPtr<CPyMathEngineOwner> mathEngineOwner = new CPyMathEngineOwner( &MathEngine(), false );
9189

92-
CPtr<const CDnnBlob> dataBlob = new CTempBlob( mathEngineOwner->MathEngine(), data, inputBlobs[0]->GetDesc() );
90+
CPtr<const CDnnBlob> dataBlob = new CPyTempDnnBlob( mathEngineOwner->MathEngine(), data, inputBlobs[0]->GetDesc() );
9391
CPtr<const CDnnBlob> var = tape.Variable( *dataBlob.Ptr() );
9492
CPyBlob dataPyBlob( *mathEngineOwner, const_cast<CDnnBlob*>(var.Ptr()) );
9593

96-
CPtr<CDnnBlob> labelBlob( new CTempBlob( mathEngineOwner->MathEngine(), label, inputBlobs[1]->GetDesc() ) );
94+
CPtr<CDnnBlob> labelBlob( new CPyTempDnnBlob( mathEngineOwner->MathEngine(), label, inputBlobs[1]->GetDesc() ) );
9795
CPyBlob labelPyBlob( *mathEngineOwner, labelBlob );
9896

9997
CPtr<CDnnBlob> value;

NeoML/Python/src/PyDnnBlob.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ static CBlobDesc createBlobDesc( TBlobType type, std::initializer_list<int> dime
115115
return desc;
116116
}
117117

118-
class CPyDnnBlob : public CDnnBlob {
118+
//------------------------------------------------------------------------------------------------------------
119+
120+
// CPyDnnBlob does not own the data behind its handle
121+
class CPyDnnBlob : public CDnnBlobView {
119122
public:
120123
CPyDnnBlob( IMathEngine& mathEngine, TBlobType type, std::initializer_list<int> dimension, py::buffer_info&& _info );
121124
virtual ~CPyDnnBlob();
@@ -128,7 +131,7 @@ class CPyDnnBlob : public CDnnBlob {
128131
};
129132

130133
CPyDnnBlob::CPyDnnBlob( IMathEngine& mathEngine, TBlobType type, std::initializer_list<int> dimension, py::buffer_info&& _info ) :
131-
CDnnBlob( mathEngine, createBlobDesc( type, dimension ), CPyMemoryHandle( &mathEngine, _info.ptr ), false ),
134+
CDnnBlobView( mathEngine, createBlobDesc( type, dimension ), CPyMemoryHandle( &mathEngine, _info.ptr ) ),
132135
info( new py::buffer_info( std::move( _info ) ) )
133136
{
134137
}

NeoML/include/NeoML/Dnn/DnnBlob.h

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ namespace NeoML {
3434

3535
class NEOML_API CDnnBlob : public IObject {
3636
public:
37-
explicit CDnnBlob( IMathEngine& mathEngine );
37+
explicit CDnnBlob( IMathEngine& mathEngine ) : mathEngine( mathEngine ) {}
3838

39-
// Move other's Blob state to this Blob and transfer its data (if dataOwned) to this thread
39+
// Move other's Blob state to this Blob and transfer its data to this thread
4040
CDnnBlob( CDnnBlob&& other );
4141
CDnnBlob& operator=( CDnnBlob&& other );
4242

@@ -65,7 +65,7 @@ class NEOML_API CDnnBlob : public IObject {
6565
static CDnnBlob* CreateBlob(IMathEngine& mathEngine, TBlobType type, const CBlobDesc& pattern);
6666

6767
// Checks if the dimensions of another blob are the same
68-
bool HasEqualDimensions(const CDnnBlob* other) const;
68+
bool HasEqualDimensions( const CDnnBlob* other ) const { return desc.HasEqualDimensions( other->desc ); }
6969

7070
// Gets the blob size along the specified dimension
7171
int DimSize(int d) const { return desc.DimSize(d); }
@@ -143,7 +143,7 @@ class NEOML_API CDnnBlob : public IObject {
143143
// Transfers CDnnBlob data from other thread owner to this thread.
144144
// By default memory underneath each blob is associated with the thread on which its allocation has occurred.
145145
// This method switches this association to the calling thread.
146-
void TransferDataToThisThread();
146+
virtual void TransferDataToThisThread();
147147

148148
// Elementwise adds a blob of the same dimensions
149149
void Add(const CDnnBlob* other);
@@ -165,7 +165,7 @@ class NEOML_API CDnnBlob : public IObject {
165165
// Changes the blob dimensions "names" without moving the data
166166
// In effect, only the blob description is changed
167167
// As the data is unaffected, the total blob size specified by the new descriptor should be the same
168-
void ReinterpretDimensions( const CBlobDesc& newDesc );
168+
virtual void ReinterpretDimensions( const CBlobDesc& newDesc );
169169

170170
// Merges blobs along the given dimension
171171
static void MergeByDim( IMathEngine& mathEngine, TBlobDim d, const CObjectArray<CDnnBlob>& from, const CPtr<CDnnBlob>& to );
@@ -193,31 +193,31 @@ class NEOML_API CDnnBlob : public IObject {
193193

194194
// Gets the MathEngine on which the blob was created
195195
IMathEngine& GetMathEngine() const { return mathEngine; }
196-
197196
// Gets the blob descriptor
198197
const CBlobDesc& GetDesc() const { return desc; }
199198
// Gets the type of data in the blob
200199
TBlobType GetDataType() const { return desc.GetDataType(); }
201200

201+
// All methods below are just the interface for the CDnnWindowBlob representation
202+
202203
// Gets the parent blob
203-
CDnnBlob* GetParent() { return parent; }
204-
const CDnnBlob* GetParent() const { return parent; }
204+
virtual CDnnBlob* GetParent() { return nullptr; }
205+
const CDnnBlob* GetParent() const { return const_cast<CDnnBlob*>( this )->GetParent(); }
205206
// Gets the blob that owns the data (and has no parent)
206-
CDnnBlob* GetOwner();
207-
const CDnnBlob* GetOwner() const { return const_cast<CDnnBlob*>(this)->GetOwner(); }
208-
207+
virtual CDnnBlob* GetOwner() { return this; }
208+
const CDnnBlob* GetOwner() const { return const_cast<CDnnBlob*>( this )->GetOwner(); }
209209
// Gets the shift in data relative to the parent blob
210210
// The position in the parent blob is calculated along the BatchLength dimension
211211
// The position equal to N would correspond to a N*BatchWidth*ListSize*Height*Width*Depth*Channels shift in the one-dimensional array
212-
int GetParentPos() const;
213-
void SetParentPos( int pos );
214-
void ShiftParentPos( int shift );
212+
virtual int GetParentPos() const { return 0; }
213+
virtual void SetParentPos( int /*pos*/ ) { NeoAssert( false ); }
214+
virtual void ShiftParentPos( int /*shift*/ ) { NeoAssert( false ); }
215215

216216
protected:
217217
~CDnnBlob() override;
218218

219-
CDnnBlob( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data, bool _dataOwned ) :
220-
mathEngine( _mathEngine ), desc( _desc ), data( _data ), dataOwned( _dataOwned ), parentPos( 0 )
219+
CDnnBlob( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data ) :
220+
mathEngine( _mathEngine ), desc( _desc ), data( _data )
221221
{
222222
NeoAssert( desc.GetDataType() != CT_Invalid );
223223
NeoAssert( &mathEngine == data.GetMathEngine() );
@@ -230,24 +230,81 @@ class NEOML_API CDnnBlob : public IObject {
230230
CBlobDesc desc;
231231
// Pointer to the allocated data storage
232232
CMemoryHandle data;
233-
// Ownership of the `data`, it means that it has full access to write and to free the allocated data storage
234-
// Either `dataOwned` is true and `parent` is 0
235-
// Or `dataOwned` is false and `parent` is pointer to blob that owns the allocated data storage
236-
bool dataOwned;
237-
// Pointer to blob with data for sequential recurent mode or reference dnn's paramBlobs
238-
CPtr<CDnnBlob> parent;
239-
// Offset in `parent` blob for sequential recurent mode, move window by BatchLength of the parent blob
240-
int parentPos;
241233

242234
void initializeBlob(TBlobType _type, int batchLength, int batchWidth, int listSize, int height, int width,
243235
int depth, int channels);
244236
void initializeTensor(TBlobType _type, std::initializer_list<int> dimensions);
245-
void initializeWindow(const CPtr<CDnnBlob>& _parent, int windowSize);
246237
void initializeByPattern(TBlobType type, const CBlobDesc& pattern);
247238

248239
friend class CDnnBlobClassRegistrar;
240+
friend class CDnnWindowBlob;
241+
friend class CDnnBlobView;
242+
};
243+
244+
//---------------------------------------------------------------------------------------------------------------------
245+
246+
// A kind of CDnnBlob that does not own its data
247+
// CDnnBlobView does not free the memory behind its data handle
248+
// Used in the Python wrappers
249+
class NEOML_API CDnnBlobView : public CDnnBlob {
250+
protected:
251+
CDnnBlobView( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
252+
CDnnBlobView( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
253+
CDnnBlob( mathEngine, desc, data )
254+
{}
255+
256+
~CDnnBlobView() { if( !data.IsNull() ) { data = CMemoryHandle{}; } } // no need to free
257+
258+
// Prohibited methods
259+
//void ReinterpretDimensions( const CBlobDesc& ) override { NeoAssert( false ); } // impossible !!! USED in Python !!!
260+
//void Serialize( CArchive& ) override // !!! PICKLED in Python !!!
261+
//{ NeoAssert( false ); } // a blob that links to another may not be serialized
262+
void TransferDataToThisThread() override {} // !!! MOVED in Python !!!
263+
};
264+
265+
//---------------------------------------------------------------------------------------------------------------------
266+
267+
// A kind of CDnnBlob that is a view into a parent blob holding a sequence (BatchLength > 1).
268+
// A CDnnWindowBlob does not own its memory; it is a technical CDnnBlob representation.
269+
// A CDnnWindowBlob represents one element (BatchLength == 1) of the sequence for recurrent networks.
270+
class NEOML_API CDnnWindowBlob : public CDnnBlob {
271+
public:
272+
// Creates a "window" blob to represent a subsequence of objects from the parent blob
273+
static CDnnBlob* CreateWindowBlob( const CPtr<CDnnBlob>& parent, int windowSize = 1 );
274+
275+
// Prohibited methods
276+
void ReinterpretDimensions( const CBlobDesc& ) override { NeoAssert( false ); } // impossible
277+
void Serialize( CArchive& ) override { NeoAssert( false ); } // a blob that links to another may not be serialized
278+
void TransferDataToThisThread() override { NeoAssert( false ); }
279+
280+
// Interface of communication
281+
CDnnBlob* GetParent() override { return parent; }
282+
CDnnBlob* GetOwner() override;
283+
int GetParentPos() const override;
284+
void SetParentPos( int pos ) override;
285+
void ShiftParentPos( int shift ) override;
286+
287+
protected:
288+
CDnnWindowBlob( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
289+
CDnnWindowBlob( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
290+
CDnnBlob( mathEngine, desc, data )
291+
{}
292+
CDnnWindowBlob( CDnnWindowBlob&& other ) = delete;
293+
CDnnWindowBlob& operator=( CDnnWindowBlob&& other ) = delete;
294+
295+
~CDnnWindowBlob() { if( parent != nullptr ) { data = CMemoryHandle{}; } } // no need to free
296+
297+
private:
298+
// Pointer to the blob with data for sequential recurrent mode, or a reference to the dnn's paramBlobs
299+
CPtr<CDnnBlob> parent;
300+
// Offset in the `parent` blob for sequential recurrent mode; the window moves along the BatchLength of the parent blob
301+
int parentPos = 0;
302+
303+
void initializeWindow( const CPtr<CDnnBlob>& parent, int windowSize );
249304
};
250305

306+
//---------------------------------------------------------------------------------------------------------------------
307+
251308
inline void SerializeBlob( IMathEngine& mathEngine, CArchive& archive, CPtr<CDnnBlob>& blob )
252309
{
253310
if( archive.IsStoring() ) {
@@ -287,6 +344,8 @@ inline void SerializeBlobs( IMathEngine& mathEngine, CArchive& archive, CObjectA
287344
}
288345
}
289346

347+
//---------------------------------------------------------------------------------------------------------------------
348+
290349
enum class TDnnBlobBufferAccess {
291350
Read,
292351
Write,
@@ -333,6 +392,8 @@ class CDnnBlobBuffer {
333392
TBufferType* ptr;
334393
};
335394

395+
//---------------------------------------------------------------------------------------------------------------------
396+
336397
inline CDnnBlob* CDnnBlob::CreateBlob( IMathEngine& mathEngine, const CBlobDesc& pattern )
337398
{
338399
return CreateBlob(mathEngine, CT_Float, pattern);
@@ -456,13 +517,15 @@ inline T* CDnnBlob::GetBuffer( int pos, int size, bool exchange )
456517
return static_cast<T*>( mathEngine.GetBuffer( data, pos * dataSize, size * dataSize, exchange ) );
457518
}
458519

459-
inline int CDnnBlob::GetParentPos() const
520+
//---------------------------------------------------------------------------------------------------------------------
521+
522+
inline int CDnnWindowBlob::GetParentPos() const
460523
{
461524
NeoAssert(parent != 0);
462525
return parentPos;
463526
}
464527

465-
inline void CDnnBlob::SetParentPos(int pos)
528+
inline void CDnnWindowBlob::SetParentPos(int pos)
466529
{
467530
int arrayPos = pos * (desc.BlobSize() / desc.BatchLength());
468531
NeoAssert(parent != 0);
@@ -480,25 +543,24 @@ inline void CDnnBlob::SetParentPos(int pos)
480543
}
481544
}
482545

483-
inline void CDnnBlob::ShiftParentPos(int shift)
546+
inline void CDnnWindowBlob::ShiftParentPos(int shift)
484547
{
485548
SetParentPos(parentPos + shift);
486549
}
487550

488-
inline bool CDnnBlob::HasEqualDimensions(const CDnnBlob* other) const
489-
{
490-
return desc.HasEqualDimensions(other->desc);
491-
}
492-
493-
inline CDnnBlob* CDnnBlob::GetOwner()
551+
inline CDnnBlob* CDnnWindowBlob::GetOwner()
494552
{
495553
CDnnBlob* result = this;
496-
while( result->parent != 0 ) {
497-
result = result->parent;
554+
CDnnWindowBlob* window = this;
555+
while( window != nullptr && window->parent != nullptr ) {
556+
result = window->parent;
557+
window = dynamic_cast<CDnnWindowBlob*>( result );
498558
}
499559
return result;
500560
}
501561

562+
//---------------------------------------------------------------------------------------------------------------------
563+
502564
template<typename TBufferType>
503565
inline CDnnBlobBuffer<TBufferType>::~CDnnBlobBuffer()
504566
{

NeoML/src/Dnn/AutoDiff.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2021 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -21,14 +21,14 @@ limitations under the License.
2121
namespace NeoML {
2222

2323
CTapeBlob::CTapeBlob( IGradientTape* _tape, const CDnnBlob& blob ) :
24-
CDnnBlob( blob.GetMathEngine(), blob.GetDesc(), blob.GetMathEngine().HeapAlloc( blob.GetDataSize() * sizeof(float) ), true ),
24+
CDnnBlob( blob.GetMathEngine(), blob.GetDesc(), blob.GetMathEngine().HeapAlloc( blob.GetDataSize() * sizeof(float) ) ),
2525
tape( _tape )
2626
{
2727
blob.GetMathEngine().VectorCopy( GetData(), blob.GetData(), blob.GetDataSize() );
2828
}
2929

3030
CTapeBlob::CTapeBlob( IGradientTape* _tape, IMathEngine& mathEngine, const CBlobDesc& desc ) :
31-
CDnnBlob( mathEngine, desc, mathEngine.HeapAlloc( desc.BlobSize() * sizeof(float) ), true ),
31+
CDnnBlob( mathEngine, desc, mathEngine.HeapAlloc( desc.BlobSize() * sizeof(float) ) ),
3232
tape( _tape )
3333
{
3434
}

0 commit comments

Comments
 (0)