@@ -285,18 +285,16 @@ void CDnnSolver::clipGradients( const CObjectArray<CDnnBlob>& paramDiffBlobs )
285
285
paramDiffBlobs[i]->GetDataSize (), tempVar );
286
286
MathEngine ().VectorAdd ( gradVar, tempVar, gradVar, 1 );
287
287
}
288
- NeoPresume ( std::isfinite ( gradVar.GetValue () ) );
289
- MathEngine ().VectorSqrt ( gradVar, gradVar, 1 );
290
288
289
+ float grad = gradVar.GetValue (); // CUDA sync
290
+ NeoPresume ( std::isfinite ( grad ) );
291
291
// Calculate scale
292
- MathEngine ().VectorMax ( gradVar, maxGradientNorm, gradVar, 1 );
293
- MathEngine ().VectorInv ( gradVar, tempVar, 1 );
294
- MathEngine ().VectorMultiply ( tempVar, tempVar, 1 , maxGradientNorm );
292
+ grad = maxGradientNorm / std::max ( sqrtf ( grad ), maxGradientNorm );
295
293
296
294
// Decrease the gradient
297
295
for ( int i = 0 ; i < paramDiffBlobs.Size (); ++i ) {
298
296
MathEngine ().VectorMultiply ( paramDiffBlobs[i]->GetData (), paramDiffBlobs[i]->GetData (),
299
- paramDiffBlobs[i]->GetDataSize (), tempVar );
297
+ paramDiffBlobs[i]->GetDataSize (), grad );
300
298
}
301
299
}
302
300
@@ -934,7 +932,6 @@ void CDnnLambGradientSolver::TrainLayer( const CBaseLayer* layer, const CObjectA
934
932
// Add squared L2-norm for calculation of L2-norm of the whole mode
935
933
if ( useNvLamb ) {
936
934
const float invSquareClipMultiplier = 1 .0f / ( clipMultiplier * clipMultiplier );
937
- // normL2Var->GetData().SetValue( 0.f ); // CUDA sync
938
935
MathEngine ().VectorSum ( TempData (), dataSize, normL2Var->GetData () );
939
936
const float layerNormL2 = normL2Var->GetData ().GetValue (); // CUDA sync
940
937
layersGradientNormSquare.Add ( invSquareClipMultiplier * layerNormL2 );
@@ -977,10 +974,8 @@ float CDnnLambGradientSolver::calcL2NormAverage( const CConstFloatHandle& data,
977
974
const float multiplier ( 1 .f / dataSize );
978
975
MathEngine ().VectorMultiply ( data, tempNormBlob->GetData (), dataSize, multiplier );
979
976
980
- // normL2Var->GetData().SetValue( 0.f ); // CUDA sync
981
977
MathEngine ().VectorDotProduct ( tempNormBlob->GetData (), tempNormBlob->GetData (), dataSize, normL2Var->GetData () );
982
- MathEngine ().VectorSqrt ( normL2Var->GetData (), normL2Var->GetData (), 1 );
983
- return normL2Var->GetData ().GetValue (); // CUDA sync
978
+ return sqrtf ( normL2Var->GetData ().GetValue () ); // CUDA sync
984
979
}
985
980
986
981
// Parameter indices, used in weightDecay
0 commit comments