
Commit 10a5d73

[NeoML] CUDA sync in DnnSolver::clipGradients
Signed-off-by: Kirill Golikov <[email protected]>
1 parent: 51a4fea

1 file changed: +5 -10 lines changed

NeoML/src/Dnn/DnnSolver.cpp

Lines changed: 5 additions & 10 deletions
@@ -285,18 +285,16 @@ void CDnnSolver::clipGradients( const CObjectArray<CDnnBlob>& paramDiffBlobs )
 			paramDiffBlobs[i]->GetDataSize(), tempVar );
 		MathEngine().VectorAdd( gradVar, tempVar, gradVar, 1 );
 	}
-	NeoPresume( std::isfinite( gradVar.GetValue() ) );
-	MathEngine().VectorSqrt( gradVar, gradVar, 1 );
 
+	float grad = gradVar.GetValue(); // CUDA sync
+	NeoPresume( std::isfinite( grad ) );
 	// Calculate scale
-	MathEngine().VectorMax( gradVar, maxGradientNorm, gradVar, 1 );
-	MathEngine().VectorInv( gradVar, tempVar, 1 );
-	MathEngine().VectorMultiply( tempVar, tempVar, 1, maxGradientNorm );
+	grad = maxGradientNorm / std::max( sqrtf( grad ), maxGradientNorm );
 
 	// Decrease the gradient
 	for( int i = 0; i < paramDiffBlobs.Size(); ++i ) {
 		MathEngine().VectorMultiply( paramDiffBlobs[i]->GetData(), paramDiffBlobs[i]->GetData(),
-			paramDiffBlobs[i]->GetDataSize(), tempVar );
+			paramDiffBlobs[i]->GetDataSize(), grad );
 	}
 }
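This hunk trades four one-element device kernels (VectorSqrt, VectorMax, VectorInv and the one-element VectorMultiply) for a single blocking GetValue() read, the "CUDA sync" of the commit title, followed by plain host arithmetic; the NeoPresume check now reuses the fetched value instead of issuing its own GetValue(). A minimal host-side sketch of the clip-by-global-norm factor being computed, with illustrative names and plain arrays rather than the NeoML API:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // sumOfSquares is what gradVar holds after the accumulation loop:
    // the squared L2 norm of all parameter gradients combined.
    float ClipScale( float sumOfSquares, float maxGradientNorm )
    {
        // Norm at or below the threshold: scale == 1, gradients pass through;
        // norm above it: scale < 1, shrinking the global norm to maxGradientNorm.
        return maxGradientNorm / std::max( std::sqrt( sumOfSquares ), maxGradientNorm );
    }

    // Counterpart of the per-blob VectorMultiply loop: scale gradients in place.
    void scaleGradients( float* grads, std::size_t size, float scale )
    {
        for( std::size_t i = 0; i < size; ++i ) {
            grads[i] *= scale;
        }
    }

After this change the GetValue() read is the only host-device synchronization left in the function; everything downstream of it is ordinary CPU arithmetic on a single float.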

@@ -934,7 +932,6 @@ void CDnnLambGradientSolver::TrainLayer( const CBaseLayer* layer, const CObjectA
 	// Add squared L2-norm for calculation of L2-norm of the whole mode
 	if( useNvLamb ) {
 		const float invSquareClipMultiplier = 1.0f / ( clipMultiplier * clipMultiplier );
-		//normL2Var->GetData().SetValue( 0.f ); // CUDA sync
 		MathEngine().VectorSum( TempData(), dataSize, normL2Var->GetData() );
 		const float layerNormL2 = normL2Var->GetData().GetValue(); // CUDA sync
 		layersGradientNormSquare.Add( invSquareClipMultiplier * layerNormL2 );
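This hunk and the next both delete an already commented-out SetValue( 0.f ): dead code whose only effect, if re-enabled, would be an extra synchronization, presumably redundant because VectorSum writes its result rather than accumulating into it. For context on the recurring "// CUDA sync" comments, reading a scalar back from the device blocks until preceding work on the stream has finished; a plain CUDA-runtime illustration (not NeoML code), assuming the default stream:

    #include <cuda_runtime.h>

    // Reading one float back from the device is a blocking call, so it
    // doubles as a synchronization point: the "CUDA sync" noted in this file.
    float readScalarFromDevice( const float* devicePtr )
    {
        float hostValue = 0.f;
        // On the default stream, cudaMemcpy waits for all preceding kernels
        // to complete before copying, then returns with the value ready.
        cudaMemcpy( &hostValue, devicePtr, sizeof( float ), cudaMemcpyDeviceToHost );
        return hostValue;
    }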
@@ -977,10 +974,8 @@ float CDnnLambGradientSolver::calcL2NormAverage( const CConstFloatHandle& data,
 	const float multiplier( 1.f / dataSize );
 	MathEngine().VectorMultiply( data, tempNormBlob->GetData(), dataSize, multiplier );
 
-	//normL2Var->GetData().SetValue( 0.f ); // CUDA sync
 	MathEngine().VectorDotProduct( tempNormBlob->GetData(), tempNormBlob->GetData(), dataSize, normL2Var->GetData() );
-	MathEngine().VectorSqrt( normL2Var->GetData(), normL2Var->GetData(), 1 );
-	return normL2Var->GetData().GetValue(); // CUDA sync
+	return sqrtf( normL2Var->GetData().GetValue() ); // CUDA sync
 }
 
 // Parameter indices, used in weightDecay
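Here the final square root moves from a one-element VectorSqrt kernel to a host-side sqrtf applied after the blocking read, saving a kernel launch without changing the result. Since the data is scaled by 1/dataSize before the dot product, the returned value is sqrt( sum( (x[i]/n)^2 ) ), i.e. the L2 norm of the vector divided by the element count n. A host-side sketch with illustrative names, not the NeoML API:

    #include <cmath>
    #include <cstddef>

    // What calcL2NormAverage evaluates, on plain arrays:
    // sqrt( sum( (x[i] / n)^2 ) ) == ||x||_2 / n.
    float l2NormAverage( const float* x, std::size_t n )
    {
        float sum = 0.f;
        for( std::size_t i = 0; i < n; ++i ) {
            const float scaled = x[i] / static_cast<float>( n );
            sum += scaled * scaled;
        }
        return std::sqrt( sum ); // the sqrt this diff moved onto the host
    }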
