
Commit ae3256d

[NeoML] DistributedTraining uses IsDnnInferenced
Signed-off-by: Kirill Golikov <[email protected]>
1 parent ef317f2 commit ae3256d

File tree: 2 files changed, +46 -21 lines

NeoML/include/NeoML/Dnn/DnnDistributed.h

Lines changed: 3 additions & 1 deletion
@@ -113,10 +113,12 @@ class NEOML_API CDistributedTraining {
 	CPointerArray<CRandom> rands;
 	// Separate dnn for each thread
 	CPointerArray<CDnn> cnns;
+	// Indicates for which dnns the inference was performed
+	CArray<bool> isDnnInferenced;
 	// Separate `batchSize` for each dnn (may be empty) in a thread
 	CArray<int> batchSize;
 	// `Train()` cannot be called if it `isFirstRun`
-	// `batchSize` may not equal 0 if it `isFirstRun` for `RunOnce`, `RunAndBackwardOnce` or `RunAndLearnOnce`.
+	// `batchSize` may not equal 0 if it `isFirstRun` for `RunAndBackwardOnce` or `RunAndLearnOnce`.
 	bool isFirstRun = true;
 	// Containers for errors, if any happened
 	CArray<CString> errorMessages;
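For context, a minimal sketch (not part of the commit) of the dataset side that makes the new flag necessary: SetInputBatch may return 0 for a thread that has no data on the current run, and such a dnn is then never inferenced. The class, the "source" layer name and the constructor arguments below are hypothetical, assuming only the IDistributedDataset interface as it is used in this diff:

// Hypothetical dataset for illustration: some threads may get no data on a run.
class CSparseDataset : public IDistributedDataset {
public:
	CSparseDataset( const CPtr<CDnnBlob>& batch, const CArray<bool>& threadHasData ) :
		batch( batch ) { threadHasData.CopyTo( hasData ); }

	int SetInputBatch( CDnn& dnn, int thread ) override
	{
		if( !hasData[thread] ) {
			return 0; // no data for this thread on this run; its dnn stays non-inferenced
		}
		CheckCast<CSourceLayer>( dnn.GetLayer( "source" ) )->SetBlob( batch );
		return batch->GetObjectCount();
	}

private:
	CPtr<CDnnBlob> batch;
	CArray<bool> hasData;
};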

NeoML/src/Dnn/DnnDistributed.cpp

Lines changed: 43 additions & 20 deletions
@@ -165,6 +165,7 @@ struct CDistributedTraining::CThreadParams final {
 	bool* const IsFirstRun;
 	IDistributedDataset* const Data;
 	CPointerArray<CDnn>& Dnns;
+	CArray<bool>* IsDnnInferenced;
 	CArray<int>& BatchSize;
 	const bool IsCpu;
 	CArray<CString>& ErrorMessages;
@@ -173,19 +174,25 @@ struct CDistributedTraining::CThreadParams final {
 
 	// RunOnce and RunAndBackwardOnce
 	CThreadParams( bool* isFirstRun, IDistributedDataset* data, CPointerArray<CDnn>& dnns,
-			CArray<int>& batchSize, bool isCpu, CArray<CString>& errorMessages ) :
+			CArray<bool>* isDnnInferenced, CArray<int>& batchSize, bool isCpu, CArray<CString>& errorMessages ) :
 		IsFirstRun( isFirstRun ),
 		Data( data ),
 		Dnns( dnns ),
+		IsDnnInferenced( isDnnInferenced ),
 		BatchSize( batchSize ),
 		IsCpu( isCpu ),
 		ErrorMessages( errorMessages )
-	{}
+	{
+		if( IsDnnInferenced != nullptr ) {
+			IsDnnInferenced->DeleteAll();
+			IsDnnInferenced->Add( false, Dnns.Size() );
+		}
+	}
 
 	// solver.Train
 	CThreadParams( CPointerArray<CDnn>& dnns,
 			CArray<int>& batchSize, int totalBatch, bool isCpu, CArray<CString>& errorMessages ) :
-		CThreadParams( nullptr, nullptr, dnns, batchSize, isCpu, errorMessages )
+		CThreadParams( nullptr, nullptr, dnns, nullptr, batchSize, isCpu, errorMessages )
 	{ TotalBatch = totalBatch; }
 
 	void SetErrorMessage( int threadIndex, CString message );
@@ -195,6 +202,7 @@ void CDistributedTraining::CThreadParams::SetErrorMessage( int threadIndex, CString message )
 {
 	IsErrorHappened = true;
 	ErrorMessages[threadIndex] = std::move( message );
+	ErrorMessages[threadIndex] += "(thread = " + Str( threadIndex ) + ")";
 	// This abort is monitored only inside:
 	// - CDnnSolver::allReduce (MathEngine.AllReduce)
 	// - CDnnDistributedInitializer::InitializeLayerParams (MathEngine.Broadcast)
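With this change every stored error carries the index of the thread that reported it: a hypothetical failure text "CUDA error: out of memory" would now be stored as "CUDA error: out of memory(thread = 2)".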
@@ -217,6 +225,7 @@ void CDistributedTraining::initialize( CArchive& archive, int count, TDistribute
 		archive.Serialize( *cnns[i] );
 		archive.Seek( 0, static_cast<CBaseFile::TSeekPosition>( 0 ) );
 	}
+	isDnnInferenced.Add( false, count );
 	batchSize.Add( 0, count );
 	errorMessages.Add( {}, count );
 }
@@ -342,22 +351,20 @@ float CDistributedTraining::GetLearningRate() const
 
 void CDistributedTraining::RunOnce( IDistributedDataset& data )
 {
-	CThreadParams function_params( &isFirstRun, &data, cnns, batchSize, isCpu, errorMessages );
+	CThreadParams function_params( nullptr, &data, cnns, &isDnnInferenced, batchSize, isCpu, errorMessages );
 
 	IThreadPool::TFunction f = [](int threadIndex, void* ptr)
 	{
 		CThreadParams& function_params = *static_cast<CThreadParams*>( ptr );
 		CPointerArray<CDnn>& cnns = function_params.Dnns;
-		CArray<int>& batchSize = function_params.BatchSize;
 		try {
 			CThreadGroupSwitcher groupSwitcher( function_params.IsCpu, threadIndex, cnns.Size() );
+			// Returns the current batch size (or 0, if there is no data for this thread on this run)
 			const int currBatchSize = function_params.Data->SetInputBatch( *cnns[threadIndex], threadIndex );
-			NeoAssert( currBatchSize > 0 || ( currBatchSize == 0 && !( *function_params.IsFirstRun ) ) );
 			if( currBatchSize > 0 ) {
-				batchSize[threadIndex] += currBatchSize;
 				cnns[threadIndex]->RunOnce();
+				function_params.IsDnnInferenced->ReplaceAt( true, threadIndex );
 			}
-			*function_params.IsFirstRun = false;
 		} catch( std::exception& e ) {
 			function_params.SetErrorMessage( threadIndex, e.what() );
 		}
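RunOnce is now purely an inference pass: the NeoAssert, the batchSize accumulation and the isFirstRun update are gone, so a thread whose dataset returns 0 is silently skipped and only its isDnnInferenced flag stays false. A usage sketch under the same assumptions as above (the construction arguments are assumed, not taken from this diff):

// Hedged sketch: distributed inference where one thread receives no data.
CDistributedTraining distributed( dnn, /*count=*/4 ); // assumed constructor usage
CArray<bool> coverage;
coverage.Add( true, 4 );
coverage[3] = false; // thread 3 gets nothing on this run

CSparseDataset dataset( batch, coverage );
distributed.RunOnce( dataset ); // thread 3 is skipped; no assert fires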
@@ -376,7 +383,7 @@ void CDistributedTraining::RunOnce( IDistributedDataset& data )
 
 void CDistributedTraining::RunAndBackwardOnce( IDistributedDataset& data )
 {
-	CThreadParams function_params( &isFirstRun, &data, cnns, batchSize, isCpu, errorMessages );
+	CThreadParams function_params( &isFirstRun, &data, cnns, &isDnnInferenced, batchSize, isCpu, errorMessages );
 
 	IThreadPool::TFunction f = [](int threadIndex, void* ptr)
 	{
@@ -385,11 +392,15 @@ void CDistributedTraining::RunAndBackwardOnce( IDistributedDataset& data )
 		CArray<int>& batchSize = function_params.BatchSize;
 		try {
 			CThreadGroupSwitcher groupSwitcher( function_params.IsCpu, threadIndex, cnns.Size() );
+			// Returns the current batch size (or 0, if there is no data for this thread on this run)
 			const int currBatchSize = function_params.Data->SetInputBatch( *cnns[threadIndex], threadIndex );
+			// This assert cannot be avoided: all of the dnns must participate in solver->Train()
 			NeoAssert( currBatchSize > 0 || ( currBatchSize == 0 && !( *function_params.IsFirstRun ) ) );
 			if( currBatchSize > 0 ) {
 				batchSize[threadIndex] += currBatchSize;
 				cnns[threadIndex]->RunAndBackwardOnce();
+				// The sink results may be retrieved after this, because RunOnce() was performed as part of it
 				function_params.IsDnnInferenced->ReplaceAt( true, threadIndex );
 			}
 			*function_params.IsFirstRun = false;
 		} catch( std::exception& e ) {
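For contrast, RunAndBackwardOnce keeps the assert and the batchSize bookkeeping because the following solver step requires every dnn to participate. A sketch of one training step under the same assumptions as the snippets above:

// Every thread must have data at least on the very first run.
distributed.RunAndBackwardOnce( dataset );
distributed.Train(); // the solver step, in which all dnns participate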
@@ -474,17 +485,23 @@ void CDistributedTraining::GetLastLoss( const CString& layerName, CArray<float>&
 void CDistributedTraining::GetLastBlob( const CString& layerName, CObjectArray<const CDnnBlob>& blobs ) const
 {
 	blobs.SetSize( cnns.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < cnns.Size(); ++i ) {
-		blobs[i] = CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob();
+		blobs[i] = ( isDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }
 
 // deprecated
 void CDistributedTraining::GetLastBlob( const CString& layerName, CObjectArray<CDnnBlob>& blobs ) const
 {
 	blobs.SetSize( cnns.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < cnns.Size(); ++i ) {
-		blobs[i] = CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob();
+		blobs[i] = ( isDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( cnns[i]->GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }
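Callers of either GetLastBlob overload must now be prepared for nullptr entries. A minimal sketch ("sink" is an assumed sink layer name):

CObjectArray<const CDnnBlob> blobs;
distributed.GetLastBlob( "sink", blobs );
for( int i = 0; i < blobs.Size(); ++i ) {
	if( blobs[i] == nullptr ) {
		continue; // dnn i was not inferenced on the last run
	}
	// read the result from blobs[i] ...
}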

@@ -518,6 +535,7 @@ struct CDistributedInference::CThreadParams final {
 
 	CThreadParams( int threadsCount, CReferenceDnnFactory& referenceDnnFactory );
 	void Initialize( IDistributedDataset& data );
+	void SetErrorMessage( int threadIndex, CString message );
 };
 
 CDistributedInference::CThreadParams::CThreadParams( int threadsCount, CReferenceDnnFactory& referenceDnnFactory )
@@ -546,6 +564,13 @@ void CDistributedInference::CThreadParams::Initialize( IDistributedDataset& data )
 	IsErrorHappened = false;
 }
 
+void CDistributedInference::CThreadParams::SetErrorMessage( int threadIndex, CString message )
+{
+	IsErrorHappened = true;
+	ErrorMessages[threadIndex] = std::move( message );
+	ErrorMessages[threadIndex] += "(thread = " + Str( threadIndex ) + ")";
+}
+
 //---------------------------------------------------------------------------------------------------------------------
 
 CDistributedInference::CDistributedInference( const CDnn& dnn, int threadsCount,
@@ -588,13 +613,11 @@ void CDistributedInference::RunOnce( IDistributedDataset& data )
 			threadParams.IsDnnInferenced[threadIndex] = true;
 		}
 	} catch( std::exception& e ) {
-		threadParams.IsErrorHappened = true;
-		threadParams.ErrorMessages[threadIndex] = e.what();
+		threadParams.SetErrorMessage( threadIndex, e.what() );
 	}
 #ifdef NEOML_USE_FINEOBJ
 	catch( CException* e ) {
-		threadParams.IsErrorHappened = true;
-		threadParams.ErrorMessages[threadIndex] = e->MessageText().CreateString();
+		threadParams.SetErrorMessage( threadIndex, e->MessageText().CreateString() );
 		delete e;
 	}
 #endif // NEOML_USE_FINEOBJ
@@ -608,12 +631,12 @@ void CDistributedInference::RunOnce( IDistributedDataset& data )
 
 void CDistributedInference::GetLastBlob( const CString& layerName, CObjectArray<const CDnnBlob>& blobs ) const
 {
-	blobs.DeleteAll();
-	blobs.Add( nullptr, threadParams->Refs.Size() );
+	blobs.SetSize( threadParams->Refs.Size() );
+	// Return blobs for all models
 	for( int i = 0; i < threadParams->Refs.Size(); ++i ) {
-		if( threadParams->IsDnnInferenced[i] ) {
-			blobs[i] = CheckCast<const CSinkLayer>( threadParams->Refs[i]->Dnn.GetLayer( layerName ) )->GetBlob();
-		}
+		blobs[i] = ( threadParams->IsDnnInferenced[i] )
+			? CheckCast<const CSinkLayer>( threadParams->Refs[i]->Dnn.GetLayer( layerName ) )->GetBlob()
+			: nullptr;
 	}
 }
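The same nullptr convention now applies to CDistributedInference. A minimal end-to-end sketch based on the signatures visible in this diff (the originating dnn and the dataset are assumed to exist; remaining constructor arguments are omitted and assumed defaulted):

// Hedged sketch: multi-threaded inference over reference dnns.
CDistributedInference inference( dnn, /*threadsCount=*/4 );
inference.RunOnce( dataset );

CObjectArray<const CDnnBlob> blobs;
inference.GetLastBlob( "sink", blobs ); // nullptr where a thread was not inferenced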