@@ -34,9 +34,9 @@ namespace NeoML {
34
34
35
35
class NEOML_API CDnnBlob : public IObject {
36
36
public:
37
- explicit CDnnBlob ( IMathEngine& mathEngine );
37
+ explicit CDnnBlob ( IMathEngine& mathEngine ) : mathEngine( mathEngine ) {}
38
38
39
- // Move other's Blob state to this Blob and transfer its data (if dataOwned) to this thread
39
+ // Move other's Blob state to this Blob and transfer its data to this thread
40
40
CDnnBlob ( CDnnBlob&& other );
41
41
CDnnBlob& operator =( CDnnBlob&& other );
42
42
@@ -65,7 +65,7 @@ class NEOML_API CDnnBlob : public IObject {
65
65
static CDnnBlob* CreateBlob (IMathEngine& mathEngine, TBlobType type, const CBlobDesc& pattern);
66
66
67
67
// Checks if the dimensions of another blob are the same
68
- bool HasEqualDimensions (const CDnnBlob* other) const ;
68
+ bool HasEqualDimensions ( const CDnnBlob* other ) const { return desc. HasEqualDimensions ( other-> desc ); }
69
69
70
70
// Gets the blob size along the specified dimension
71
71
int DimSize (int d) const { return desc.DimSize (d); }
@@ -143,7 +143,7 @@ class NEOML_API CDnnBlob : public IObject {
143
143
// Transfers CDnnBlob data from other thread owner to this thread.
144
144
// By default memory underneath each blob is associated with the thread on which its allocation has occurred.
145
145
// This method switches this association to the calling thread.
146
- void TransferDataToThisThread ();
146
+ virtual void TransferDataToThisThread ();
147
147
148
148
// Elementwise adds a blob of the same dimensions
149
149
void Add (const CDnnBlob* other);
@@ -165,7 +165,7 @@ class NEOML_API CDnnBlob : public IObject {
165
165
// Changes the blob dimensions "names" without moving the data
166
166
// In effect, only the blob description is changed
167
167
// As the data is unaffected, the total blob size specified by the new descriptor should be the same
168
- void ReinterpretDimensions ( const CBlobDesc& newDesc );
168
+ virtual void ReinterpretDimensions ( const CBlobDesc& newDesc );
169
169
170
170
// Merges blobs along the given dimension
171
171
static void MergeByDim ( IMathEngine& mathEngine, TBlobDim d, const CObjectArray<CDnnBlob>& from, const CPtr<CDnnBlob>& to );
@@ -193,31 +193,31 @@ class NEOML_API CDnnBlob : public IObject {
193
193
194
194
// Gets the pointer to the MathEngine on which the blob was created
195
195
IMathEngine& GetMathEngine () const { return mathEngine; }
196
-
197
196
// Gets the blob descriptor
198
197
const CBlobDesc& GetDesc () const { return desc; }
199
198
// Gets the type of data in the blob
200
199
TBlobType GetDataType () const { return desc.GetDataType (); }
201
200
201
+ // All methods below are just the interface for the CDnnWindowBlob representation
202
+
202
203
// Gets the parent blob
203
- CDnnBlob* GetParent () { return parent ; }
204
- const CDnnBlob* GetParent () const { return parent ; }
204
+ virtual CDnnBlob* GetParent () { return nullptr ; }
205
+ const CDnnBlob* GetParent () const { return const_cast <CDnnBlob*>( this )-> GetParent () ; }
205
206
// Gets the blob that owns the data (and has no parent)
206
- CDnnBlob* GetOwner ();
207
- const CDnnBlob* GetOwner () const { return const_cast <CDnnBlob*>(this )->GetOwner (); }
208
-
207
+ virtual CDnnBlob* GetOwner () { return this ; }
208
+ const CDnnBlob* GetOwner () const { return const_cast <CDnnBlob*>( this )->GetOwner (); }
209
209
// Gets the shift in data relative to the parent blob
210
210
// The position in the parent blob is calculated along the BatchLength dimension
211
211
// The position equal to N would correspond to a N*BatchWidth*ListSize*Height*Width*Depth*Channels shift in the one-dimensional array
212
- int GetParentPos () const ;
213
- void SetParentPos ( int pos );
214
- void ShiftParentPos ( int shift );
212
+ virtual int GetParentPos () const { return 0 ; }
213
+ virtual void SetParentPos ( int /* pos*/ ) { NeoAssert ( false ); }
214
+ virtual void ShiftParentPos ( int /* shift*/ ) { NeoAssert ( false ); }
215
215
216
216
protected:
217
217
~CDnnBlob () override ;
218
218
219
- CDnnBlob ( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data, bool _dataOwned ) :
220
- mathEngine ( _mathEngine ), desc( _desc ), data( _data ), dataOwned( _dataOwned ), parentPos( 0 )
219
+ CDnnBlob ( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data ) :
220
+ mathEngine ( _mathEngine ), desc( _desc ), data( _data )
221
221
{
222
222
NeoAssert ( desc.GetDataType () != CT_Invalid );
223
223
NeoAssert ( &mathEngine == data.GetMathEngine () );
@@ -230,24 +230,81 @@ class NEOML_API CDnnBlob : public IObject {
230
230
CBlobDesc desc;
231
231
// Pointer to the allocated data storage
232
232
CMemoryHandle data;
233
- // Ownership of the `data`, it means that it has full access to write and to free the allocated data storage
234
- // Either `dataOwned` is true and `parent` is 0
235
- // Or `dataOwned` is false and `parent` is pointer to blob that owns the allocated data storage
236
- bool dataOwned;
237
- // Pointer to blob with data for sequential recurent mode or reference dnn's paramBlobs
238
- CPtr<CDnnBlob> parent;
239
- // Offset in `parent` blob for sequential recurent mode, move window by BatchLength of the parent blob
240
- int parentPos;
241
233
242
234
void initializeBlob (TBlobType _type, int batchLength, int batchWidth, int listSize, int height, int width,
243
235
int depth, int channels);
244
236
void initializeTensor (TBlobType _type, std::initializer_list<int > dimensions);
245
- void initializeWindow (const CPtr<CDnnBlob>& _parent, int windowSize);
246
237
void initializeByPattern (TBlobType type, const CBlobDesc& pattern);
247
238
248
239
friend class CDnnBlobClassRegistrar ;
240
+ friend class CDnnWindowBlob ;
241
+ friend class CDnnBlobView ;
242
+ };
243
+
244
+ // ---------------------------------------------------------------------------------------------------------------------
245
+
246
+ // The kind of CDnnBlob does not own the data
247
+ // CDnnBlobView does not clear data handler
248
+ // Used in python wrappers
249
+ class NEOML_API CDnnBlobView : public CDnnBlob {
250
+ protected:
251
+ CDnnBlobView ( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
252
+ CDnnBlobView ( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
253
+ CDnnBlob ( mathEngine, desc, data )
254
+ {}
255
+
256
+ ~CDnnBlobView () { if ( !data.IsNull () ) { data = CMemoryHandle{}; } } // no need to free
257
+
258
+ // Prohibited methods
259
+ // void ReinterpretDimensions( const CBlobDesc& ) override { NeoAssert( false ); } // impossible !!! USED in Python !!!
260
+ // void Serialize( CArchive& ) override // !!! PICKLED in Python !!!
261
+ // { NeoAssert( false ); } // a blob that links to another may not be serialized
262
+ void TransferDataToThisThread () override {} // !!! MOVED in Python !!!
263
+ };
264
+
265
+ // ---------------------------------------------------------------------------------------------------------------------
266
+
267
+ // The kind of CDnnBlob that is view of the some parent Blob as sequence (BatchLength > 1).
268
+ // This CDnnWindowBlob do not owner of its memory, technical CDnnBlob representation.
269
+ // This CDnnWindowBlob represents 1 element (BatchLength == 1) of the sequence for the recursive networks.
270
+ class NEOML_API CDnnWindowBlob : public CDnnBlob {
271
+ public:
272
+ // Creates a "window" blob to represent a subsequence of objects from the parent blob
273
+ static CDnnBlob* CreateWindowBlob ( const CPtr<CDnnBlob>& parent, int windowSize = 1 );
274
+
275
+ // Prohibited methods
276
+ void ReinterpretDimensions ( const CBlobDesc& ) override { NeoAssert ( false ); } // impossible
277
+ void Serialize ( CArchive& ) override { NeoAssert ( false ); } // a blob that links to another may not be serialized
278
+ void TransferDataToThisThread () override { NeoAssert ( false ); }
279
+
280
+ // Interface of communication
281
+ CDnnBlob* GetParent () override { return parent; }
282
+ CDnnBlob* GetOwner () override ;
283
+ int GetParentPos () const override ;
284
+ void SetParentPos ( int pos ) override ;
285
+ void ShiftParentPos ( int shift ) override ;
286
+
287
+ protected:
288
+ CDnnWindowBlob ( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
289
+ CDnnWindowBlob ( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
290
+ CDnnBlob ( mathEngine, desc, data )
291
+ {}
292
+ CDnnWindowBlob ( CDnnWindowBlob&& other ) = delete ;
293
+ CDnnWindowBlob& operator =( CDnnWindowBlob&& other ) = delete ;
294
+
295
+ ~CDnnWindowBlob () { if ( parent != nullptr ) { data = CMemoryHandle{}; } } // no need to free
296
+
297
+ private:
298
+ // Pointer to blob with data for sequential recurent mode or reference dnn's paramBlobs
299
+ CPtr<CDnnBlob> parent;
300
+ // Offset in `parent` blob for sequential recurent mode, move window by BatchLength of the parent blob
301
+ int parentPos = 0 ;
302
+
303
+ void initializeWindow ( const CPtr<CDnnBlob>& parent, int windowSize );
249
304
};
250
305
306
+ // ---------------------------------------------------------------------------------------------------------------------
307
+
251
308
inline void SerializeBlob ( IMathEngine& mathEngine, CArchive& archive, CPtr<CDnnBlob>& blob )
252
309
{
253
310
if ( archive.IsStoring () ) {
@@ -287,6 +344,8 @@ inline void SerializeBlobs( IMathEngine& mathEngine, CArchive& archive, CObjectA
287
344
}
288
345
}
289
346
347
+ // ---------------------------------------------------------------------------------------------------------------------
348
+
290
349
enum class TDnnBlobBufferAccess {
291
350
Read,
292
351
Write,
@@ -333,6 +392,8 @@ class CDnnBlobBuffer {
333
392
TBufferType* ptr;
334
393
};
335
394
395
+ // ---------------------------------------------------------------------------------------------------------------------
396
+
336
397
inline CDnnBlob* CDnnBlob::CreateBlob ( IMathEngine& mathEngine, const CBlobDesc& pattern )
337
398
{
338
399
return CreateBlob (mathEngine, CT_Float, pattern);
@@ -456,13 +517,15 @@ inline T* CDnnBlob::GetBuffer( int pos, int size, bool exchange )
456
517
return static_cast <T*>( mathEngine.GetBuffer ( data, pos * dataSize, size * dataSize, exchange ) );
457
518
}
458
519
459
- inline int CDnnBlob::GetParentPos () const
520
+ // ---------------------------------------------------------------------------------------------------------------------
521
+
522
+ inline int CDnnWindowBlob::GetParentPos () const
460
523
{
461
524
NeoAssert (parent != 0 );
462
525
return parentPos;
463
526
}
464
527
465
- inline void CDnnBlob ::SetParentPos (int pos)
528
+ inline void CDnnWindowBlob ::SetParentPos (int pos)
466
529
{
467
530
int arrayPos = pos * (desc.BlobSize () / desc.BatchLength ());
468
531
NeoAssert (parent != 0 );
@@ -480,25 +543,24 @@ inline void CDnnBlob::SetParentPos(int pos)
480
543
}
481
544
}
482
545
483
- inline void CDnnBlob ::ShiftParentPos (int shift)
546
+ inline void CDnnWindowBlob ::ShiftParentPos (int shift)
484
547
{
485
548
SetParentPos (parentPos + shift);
486
549
}
487
550
488
- inline bool CDnnBlob::HasEqualDimensions (const CDnnBlob* other) const
489
- {
490
- return desc.HasEqualDimensions (other->desc );
491
- }
492
-
493
- inline CDnnBlob* CDnnBlob::GetOwner ()
551
+ inline CDnnBlob* CDnnWindowBlob::GetOwner ()
494
552
{
495
553
CDnnBlob* result = this ;
496
- while ( result->parent != 0 ) {
497
- result = result->parent ;
554
+ CDnnWindowBlob* window = this ;
555
+ while ( window != nullptr && window->parent != nullptr ) {
556
+ result = window->parent ;
557
+ window = dynamic_cast <CDnnWindowBlob*>( result );
498
558
}
499
559
return result;
500
560
}
501
561
562
+ // ---------------------------------------------------------------------------------------------------------------------
563
+
502
564
template <typename TBufferType>
503
565
inline CDnnBlobBuffer<TBufferType>::~CDnnBlobBuffer ()
504
566
{
0 commit comments