@@ -165,6 +165,7 @@ struct CDistributedTraining::CThreadParams final {
 	bool* const IsFirstRun;
 	IDistributedDataset* const Data;
 	CPointerArray<CDnn>& Dnns;
+	CArray<bool>* IsDnnInferenced;
 	CArray<int>& BatchSize;
 	const bool IsCpu;
 	CArray<CString>& ErrorMessages;
@@ -173,19 +174,25 @@ struct CDistributedTraining::CThreadParams final {

 	// RunOnce and RunAndBackwardOnce
 	CThreadParams( bool* isFirstRun, IDistributedDataset* data, CPointerArray<CDnn>& dnns,
-			CArray<int>& batchSize, bool isCpu, CArray<CString>& errorMessages ) :
+			CArray<bool>* isDnnInferenced, CArray<int>& batchSize, bool isCpu, CArray<CString>& errorMessages ) :
 		IsFirstRun( isFirstRun ),
 		Data( data ),
 		Dnns( dnns ),
+		IsDnnInferenced( isDnnInferenced ),
 		BatchSize( batchSize ),
 		IsCpu( isCpu ),
 		ErrorMessages( errorMessages )
-	{}
+	{
+		if( IsDnnInferenced != nullptr ) {
+			IsDnnInferenced->DeleteAll();
+			IsDnnInferenced->Add( false, Dnns.Size() );
+		}
+	}

 	// solver.Train
 	CThreadParams( CPointerArray<CDnn>& dnns,
 			CArray<int>& batchSize, int totalBatch, bool isCpu, CArray<CString>& errorMessages ) :
-		CThreadParams( nullptr, nullptr, dnns, batchSize, isCpu, errorMessages )
+		CThreadParams( nullptr, nullptr, dnns, nullptr, batchSize, isCpu, errorMessages )
 	{ TotalBatch = totalBatch; }

 	void SetErrorMessage( int threadIndex, CString message );
@@ -195,6 +202,7 @@ void CDistributedTraining::CThreadParams::SetErrorMessage( int threadIndex, CStr
 {
 	IsErrorHappened = true;
 	ErrorMessages[threadIndex] = std::move( message );
+	ErrorMessages[threadIndex] += " (thread = " + Str( threadIndex ) + ")";
 	// This abort is monitored only inside:
 	// - CDnnSolver::allReduce (MathEngine.AllReduce)
 	// - CDnnDistributedInitializer::InitializeLayerParams (MathEngine.Broadcast)
@@ -217,6 +225,7 @@ void CDistributedTraining::initialize( CArchive& archive, int count, TDistribute
 		archive.Serialize( *cnns[i] );
 		archive.Seek( 0, static_cast<CBaseFile::TSeekPosition>( 0 ) );
 	}
+	isDnnInferenced.Add( false, count );
 	batchSize.Add( 0, count );
 	errorMessages.Add( {}, count );
 }
@@ -342,22 +351,20 @@ float CDistributedTraining::GetLearningRate() const

 void CDistributedTraining::RunOnce( IDistributedDataset& data )
 {
-	CThreadParams function_params( &isFirstRun, &data, cnns, batchSize, isCpu, errorMessages );
+	CThreadParams function_params( nullptr, &data, cnns, &isDnnInferenced, batchSize, isCpu, errorMessages );

 	IThreadPool::TFunction f = []( int threadIndex, void* ptr )
 	{
 		CThreadParams& function_params = *static_cast<CThreadParams*>( ptr );
 		CPointerArray<CDnn>& cnns = function_params.Dnns;
-		CArray<int>& batchSize = function_params.BatchSize;
 		try {
 			CThreadGroupSwitcher groupSwitcher( function_params.IsCpu, threadIndex, cnns.Size() );
+			// Returns the current batch size (or 0, if there is no data for this thread on this run)
 			const int currBatchSize = function_params.Data->SetInputBatch( *cnns[threadIndex], threadIndex );
-			NeoAssert( currBatchSize > 0 || ( currBatchSize == 0 && !( *function_params.IsFirstRun ) ) );
 			if( currBatchSize > 0 ) {
-				batchSize[threadIndex] += currBatchSize;
 				cnns[threadIndex]->RunOnce();
+				function_params.IsDnnInferenced->ReplaceAt( true, threadIndex );
 			}
-			*function_params.IsFirstRun = false;
 		} catch( std::exception& e ) {
 			function_params.SetErrorMessage( threadIndex, e.what() );
 		}
@@ -376,7 +383,7 @@ void CDistributedTraining::RunOnce( IDistributedDataset& data )

 void CDistributedTraining::RunAndBackwardOnce( IDistributedDataset& data )
 {
-	CThreadParams function_params( &isFirstRun, &data, cnns, batchSize, isCpu, errorMessages );
+	CThreadParams function_params( &isFirstRun, &data, cnns, &isDnnInferenced, batchSize, isCpu, errorMessages );

 	IThreadPool::TFunction f = []( int threadIndex, void* ptr )
 	{
@@ -385,11 +392,15 @@ void CDistributedTraining::RunAndBackwardOnce( IDistributedDataset& data )
 		CArray<int>& batchSize = function_params.BatchSize;
 		try {
 			CThreadGroupSwitcher groupSwitcher( function_params.IsCpu, threadIndex, cnns.Size() );
+			// Returns the current batch size (or 0, if there is no data for this thread on this run)
 			const int currBatchSize = function_params.Data->SetInputBatch( *cnns[threadIndex], threadIndex );
+			// This assert cannot be avoided: all dnns should participate in solver->Train()
 			NeoAssert( currBatchSize > 0 || ( currBatchSize == 0 && !( *function_params.IsFirstRun ) ) );
 			if( currBatchSize > 0 ) {
 				batchSize[threadIndex] += currBatchSize;
 				cnns[threadIndex]->RunAndBackwardOnce();
+				// The sink results may be retrieved after this point, because RunOnce() has been performed
+				function_params.IsDnnInferenced->ReplaceAt( true, threadIndex );
 			}
 			*function_params.IsFirstRun = false;
 		} catch( std::exception& e ) {
@@ -474,17 +485,23 @@ void CDistributedTraining::GetLastLoss( const CString& layerName, CArray<float>&
 void CDistributedTraining::GetLastBlob( const CString& layerName, CObjectArray<const CDnnBlob>& blobs ) const
 {
 	blobs.SetSize( cnns.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < cnns.Size(); ++i ) {
-		blobs[i] = CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob();
+		blobs[i] = ( isDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }

 // deprecated
 void CDistributedTraining::GetLastBlob( const CString& layerName, CObjectArray<CDnnBlob>& blobs ) const
 {
 	blobs.SetSize( cnns.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < cnns.Size(); ++i ) {
-		blobs[i] = CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob();
+		blobs[i] = ( isDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }

@@ -518,6 +535,7 @@ struct CDistributedInference::CThreadParams final {

 	CThreadParams( int threadsCount, CReferenceDnnFactory& referenceDnnFactory );
 	void Initialize( IDistributedDataset& data );
+	void SetErrorMessage( int threadIndex, CString message );
 };

 CDistributedInference::CThreadParams::CThreadParams( int threadsCount, CReferenceDnnFactory& referenceDnnFactory )
@@ -546,6 +564,13 @@ void CDistributedInference::CThreadParams::Initialize( IDistributedDataset& data
 	IsErrorHappened = false;
 }

+void CDistributedInference::CThreadParams::SetErrorMessage( int threadIndex, CString message )
+{
+	IsErrorHappened = true;
+	ErrorMessages[threadIndex] = std::move( message );
+	ErrorMessages[threadIndex] += " (thread = " + Str( threadIndex ) + ")";
+}
+
 // ---------------------------------------------------------------------------------------------------------------------

 CDistributedInference::CDistributedInference( const CDnn& dnn, int threadsCount,
@@ -588,13 +613,11 @@ void CDistributedInference::RunOnce( IDistributedDataset& data )
 				threadParams.IsDnnInferenced[threadIndex] = true;
 			}
 		} catch( std::exception& e ) {
-			threadParams.IsErrorHappened = true;
-			threadParams.ErrorMessages[threadIndex] = e.what();
+			threadParams.SetErrorMessage( threadIndex, e.what() );
 		}
 #ifdef NEOML_USE_FINEOBJ
 		catch( CException* e ) {
-			threadParams.IsErrorHappened = true;
-			threadParams.ErrorMessages[threadIndex] = e->MessageText().CreateString();
+			threadParams.SetErrorMessage( threadIndex, e->MessageText().CreateString() );
 			delete e;
 		}
 #endif // NEOML_USE_FINEOBJ
@@ -608,12 +631,12 @@ void CDistributedInference::RunOnce( IDistributedDataset& data )

 void CDistributedInference::GetLastBlob( const CString& layerName, CObjectArray<const CDnnBlob>& blobs ) const
 {
-	blobs.DeleteAll();
-	blobs.Add( nullptr, threadParams->Refs.Size() );
+	blobs.SetSize( threadParams->Refs.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < threadParams->Refs.Size(); ++i ) {
-		if( threadParams->IsDnnInferenced[i] ) {
-			blobs[i] = CheckCast<const CSinkLayer>( threadParams->Refs[i]->Dnn.GetLayer( layerName ) )->GetBlob();
-		}
+		blobs[i] = ( threadParams->IsDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( threadParams->Refs[i]->Dnn.GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }

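Net effect of this diff for callers: both `GetLastBlob` overloads now return `nullptr` for every model that was not inferenced on the last run, instead of leaving a stale or unset entry. A minimal caller-side sketch of the new contract, assuming an already initialized `CDistributedTraining` named `distributed`, a dataset `dataset`, and a `CSinkLayer` named "sink" (these names are hypothetical, not part of this change):

	// Hypothetical caller; assumes the network contains a CSinkLayer named "sink".
	distributed.RunOnce( dataset ); // some threads may receive a 0-sized batch

	CObjectArray<const CDnnBlob> blobs;
	distributed.GetLastBlob( "sink", blobs );
	for( int i = 0; i < blobs.Size(); ++i ) {
		if( blobs[i] == nullptr ) {
			continue; // this model was not inferenced on the last RunOnce
		}
		// ... consume the inference result in blobs[i] ...
	}

The same nullptr convention applies to `CDistributedInference::GetLastBlob`.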