@@ -34,9 +34,9 @@ namespace NeoML {
34
34
35
35
class NEOML_API CDnnBlob : public IObject {
36
36
public:
37
- explicit CDnnBlob ( IMathEngine& mathEngine );
37
+ explicit CDnnBlob ( IMathEngine& mathEngine ) : mathEngine( mathEngine ) {}
38
38
39
- // Move other's Blob state to this Blob and transfer its data (if dataOwned) to this thread
39
+ // Move other's Blob state to this Blob and transfer its data to this thread
40
40
CDnnBlob ( CDnnBlob&& other );
41
41
CDnnBlob& operator =( CDnnBlob&& other );
42
42
@@ -65,7 +65,7 @@ class NEOML_API CDnnBlob : public IObject {
65
65
static CDnnBlob* CreateBlob (IMathEngine& mathEngine, TBlobType type, const CBlobDesc& pattern);
66
66
67
67
// Checks if the dimensions of another blob are the same
68
- bool HasEqualDimensions (const CDnnBlob* other) const ;
68
+ bool HasEqualDimensions ( const CDnnBlob* other ) const { return desc. HasEqualDimensions ( other-> desc ); }
69
69
70
70
// Gets the blob size along the specified dimension
71
71
int DimSize (int d) const { return desc.DimSize (d); }
@@ -143,7 +143,7 @@ class NEOML_API CDnnBlob : public IObject {
143
143
// Transfers CDnnBlob data from other thread owner to this thread.
144
144
// By default memory underneath each blob is associated with the thread on which its allocation has occurred.
145
145
// This method switches this association to the calling thread.
146
- void TransferDataToThisThread ();
146
+ virtual void TransferDataToThisThread ();
147
147
148
148
// Elementwise adds a blob of the same dimensions
149
149
void Add (const CDnnBlob* other);
@@ -165,7 +165,7 @@ class NEOML_API CDnnBlob : public IObject {
165
165
// Changes the blob dimensions "names" without moving the data
166
166
// In effect, only the blob description is changed
167
167
// As the data is unaffected, the total blob size specified by the new descriptor should be the same
168
- void ReinterpretDimensions ( const CBlobDesc& newDesc );
168
+ virtual void ReinterpretDimensions ( const CBlobDesc& newDesc );
169
169
170
170
// Merges blobs along the given dimension
171
171
static void MergeByDim ( IMathEngine& mathEngine, TBlobDim d, const CObjectArray<CDnnBlob>& from, const CPtr<CDnnBlob>& to );
@@ -193,31 +193,31 @@ class NEOML_API CDnnBlob : public IObject {
193
193
194
194
// Gets the pointer to the MathEngine on which the blob was created
195
195
IMathEngine& GetMathEngine () const { return mathEngine; }
196
-
197
196
// Gets the blob descriptor
198
197
const CBlobDesc& GetDesc () const { return desc; }
199
198
// Gets the type of data in the blob
200
199
TBlobType GetDataType () const { return desc.GetDataType (); }
201
200
201
+ // All methods below are just the interface for the CDnnWindowBlob representation
202
+
202
203
// Gets the parent blob
203
- CDnnBlob* GetParent () { return parent ; }
204
- const CDnnBlob* GetParent () const { return parent ; }
204
+ virtual CDnnBlob* GetParent () { return nullptr ; }
205
+ const CDnnBlob* GetParent () const { return const_cast <CDnnBlob*>( this )-> GetParent () ; }
205
206
// Gets the blob that owns the data (and has no parent)
206
- CDnnBlob* GetOwner ();
207
- const CDnnBlob* GetOwner () const { return const_cast <CDnnBlob*>(this )->GetOwner (); }
208
-
207
+ virtual CDnnBlob* GetOwner () { return this ; }
208
+ const CDnnBlob* GetOwner () const { return const_cast <CDnnBlob*>( this )->GetOwner (); }
209
209
// Gets the shift in data relative to the parent blob
210
210
// The position in the parent blob is calculated along the BatchLength dimension
211
211
// The position equal to N would correspond to a N*BatchWidth*ListSize*Height*Width*Depth*Channels shift in the one-dimensional array
212
- int GetParentPos () const ;
213
- void SetParentPos ( int pos );
214
- void ShiftParentPos ( int shift );
212
+ virtual int GetParentPos () const { return 0 ; }
213
+ virtual void SetParentPos ( int /* pos*/ ) { NeoAssert ( false ); }
214
+ virtual void ShiftParentPos ( int /* shift*/ ) { NeoAssert ( false ); }
215
215
216
216
protected:
217
217
~CDnnBlob () override ;
218
218
219
- CDnnBlob ( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data, bool _dataOwned ) :
220
- mathEngine ( _mathEngine ), desc( _desc ), data( _data ), dataOwned( _dataOwned ), parentPos( 0 )
219
+ CDnnBlob ( IMathEngine& _mathEngine, const CBlobDesc& _desc, CMemoryHandle _data ) :
220
+ mathEngine ( _mathEngine ), desc( _desc ), data( _data )
221
221
{
222
222
NeoAssert ( desc.GetDataType () != CT_Invalid );
223
223
NeoAssert ( &mathEngine == data.GetMathEngine () );
@@ -227,20 +227,79 @@ class NEOML_API CDnnBlob : public IObject {
227
227
IMathEngine& mathEngine;
228
228
CBlobDesc desc;
229
229
CMemoryHandle data;
230
- bool dataOwned;
231
-
232
- CPtr<CDnnBlob> parent; // parent blob
233
- int parentPos;
234
230
235
231
void initializeBlob (TBlobType _type, int batchLength, int batchWidth, int listSize, int height, int width,
236
232
int depth, int channels);
237
233
void initializeTensor (TBlobType _type, std::initializer_list<int > dimensions);
238
- void initializeWindow (const CPtr<CDnnBlob>& _parent, int windowSize);
239
234
void initializeByPattern (TBlobType type, const CBlobDesc& pattern);
240
235
241
236
friend class CDnnBlobClassRegistrar ;
237
+ friend class CDnnWindowBlob ;
238
+ friend class CDnnBlobView ;
242
239
};
243
240
241
+ // ---------------------------------------------------------------------------------------------------------------------
242
+
243
+ // The kind of CDnnBlob does not own the data
244
+ // CDnnBlobView does not clear data handler
245
+ // Used in python wrappers
246
+ class NEOML_API CDnnBlobView : public CDnnBlob {
247
+ protected:
248
+ CDnnBlobView ( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
249
+ CDnnBlobView ( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
250
+ CDnnBlob ( mathEngine, desc, data )
251
+ {}
252
+
253
+ ~CDnnBlobView () { if ( !data.IsNull () ) { data = CMemoryHandle{}; } } // no need to free
254
+
255
+ // Prohibited methods
256
+ // void ReinterpretDimensions( const CBlobDesc& ) override { NeoAssert( false ); } // impossible !!! USED in Python !!!
257
+ // void Serialize( CArchive& ) override // !!! PICKLED in Python !!!
258
+ // { NeoAssert( false ); } // a blob that links to another may not be serialized
259
+ void TransferDataToThisThread () override {} // !!! MOVED in Python !!!
260
+ };
261
+
262
+ // ---------------------------------------------------------------------------------------------------------------------
263
+
264
+ // The kind of CDnnBlob that is view of the some parent Blob as sequence (BatchLength > 1).
265
+ // This CDnnWindowBlob do not owner of its memory, technical CDnnBlob representation.
266
+ // This CDnnWindowBlob represents 1 element (BatchLength == 1) of the sequence for the recursive networks.
267
+ class NEOML_API CDnnWindowBlob : public CDnnBlob {
268
+ public:
269
+ // Creates a "window" blob to represent a subsequence of objects from the parent blob
270
+ static CDnnBlob* CreateWindowBlob ( const CPtr<CDnnBlob>& parent, int windowSize = 1 );
271
+
272
+ // Prohibited methods
273
+ void ReinterpretDimensions ( const CBlobDesc& ) override { NeoAssert ( false ); } // impossible
274
+ void Serialize ( CArchive& ) override { NeoAssert ( false ); } // a blob that links to another may not be serialized
275
+ void TransferDataToThisThread () override { NeoAssert ( false ); }
276
+
277
+ // Interface of communication
278
+ CDnnBlob* GetParent () override { return parent; }
279
+ CDnnBlob* GetOwner () override ;
280
+ int GetParentPos () const override ;
281
+ void SetParentPos ( int pos ) override ;
282
+ void ShiftParentPos ( int shift ) override ;
283
+
284
+ protected:
285
+ CDnnWindowBlob ( IMathEngine& mathEngine ) : CDnnBlob( mathEngine ) {}
286
+ CDnnWindowBlob ( IMathEngine& mathEngine, const CBlobDesc& desc, CMemoryHandle data ) :
287
+ CDnnBlob ( mathEngine, desc, data )
288
+ {}
289
+ CDnnWindowBlob ( CDnnWindowBlob&& other ) = delete ;
290
+ CDnnWindowBlob& operator =( CDnnWindowBlob&& other ) = delete ;
291
+
292
+ ~CDnnWindowBlob () { if ( parent != nullptr ) { data = CMemoryHandle{}; } } // no need to free
293
+
294
+ private:
295
+ CPtr<CDnnBlob> parent; // parent blob
296
+ int parentPos = 0 ;
297
+
298
+ void initializeWindow ( const CPtr<CDnnBlob>& parent, int windowSize );
299
+ };
300
+
301
+ // ---------------------------------------------------------------------------------------------------------------------
302
+
244
303
inline void SerializeBlob ( IMathEngine& mathEngine, CArchive& archive, CPtr<CDnnBlob>& blob )
245
304
{
246
305
if ( archive.IsStoring () ) {
@@ -286,6 +345,8 @@ enum class TDnnBlobBufferAccess {
286
345
ReadWrite
287
346
};
288
347
348
+ // ---------------------------------------------------------------------------------------------------------------------
349
+
289
350
// RAII-helper to safely work with `CDnnBlob::GetBuffer`/`CDnnBlob::ReleaseBuffer`.
290
351
template <typename TBufferType = float >
291
352
class CDnnBlobBuffer {
@@ -326,6 +387,8 @@ class CDnnBlobBuffer {
326
387
TBufferType* ptr;
327
388
};
328
389
390
+ // ---------------------------------------------------------------------------------------------------------------------
391
+
329
392
inline CDnnBlob* CDnnBlob::CreateBlob ( IMathEngine& mathEngine, const CBlobDesc& pattern )
330
393
{
331
394
return CreateBlob (mathEngine, CT_Float, pattern);
@@ -449,13 +512,15 @@ inline T* CDnnBlob::GetBuffer( int pos, int size, bool exchange )
449
512
return static_cast <T*>( mathEngine.GetBuffer ( data, pos * dataSize, size * dataSize, exchange ) );
450
513
}
451
514
452
- inline int CDnnBlob::GetParentPos () const
515
+ // ---------------------------------------------------------------------------------------------------------------------
516
+
517
+ inline int CDnnWindowBlob::GetParentPos () const
453
518
{
454
519
NeoAssert (parent != 0 );
455
520
return parentPos;
456
521
}
457
522
458
- inline void CDnnBlob ::SetParentPos (int pos)
523
+ inline void CDnnWindowBlob ::SetParentPos (int pos)
459
524
{
460
525
int arrayPos = pos * (desc.BlobSize () / desc.BatchLength ());
461
526
NeoAssert (parent != 0 );
@@ -473,25 +538,24 @@ inline void CDnnBlob::SetParentPos(int pos)
473
538
}
474
539
}
475
540
476
- inline void CDnnBlob ::ShiftParentPos (int shift)
541
+ inline void CDnnWindowBlob ::ShiftParentPos (int shift)
477
542
{
478
543
SetParentPos (parentPos + shift);
479
544
}
480
545
481
- inline bool CDnnBlob::HasEqualDimensions (const CDnnBlob* other) const
482
- {
483
- return desc.HasEqualDimensions (other->desc );
484
- }
485
-
486
- inline CDnnBlob* CDnnBlob::GetOwner ()
546
+ inline CDnnBlob* CDnnWindowBlob::GetOwner ()
487
547
{
488
548
CDnnBlob* result = this ;
489
- while ( result->parent != 0 ) {
490
- result = result->parent ;
549
+ CDnnWindowBlob* window = this ;
550
+ while ( window != nullptr && window->parent != nullptr ) {
551
+ result = window->parent ;
552
+ window = dynamic_cast <CDnnWindowBlob*>( result );
491
553
}
492
554
return result;
493
555
}
494
556
557
+ // ---------------------------------------------------------------------------------------------------------------------
558
+
495
559
template <typename TBufferType>
496
560
inline CDnnBlobBuffer<TBufferType>::~CDnnBlobBuffer ()
497
561
{
0 commit comments