
Commit 7fd4b05

[NeoML] CBinaryCrossEntropyLossLayer mem-optimize
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 95296eb · commit 7fd4b05

File tree: 2 files changed (+47 additions, -56 deletions)


NeoML/include/NeoML/Dnn/Layers/LossLayer.h
Lines changed: 6 additions & 5 deletions

@@ -174,12 +174,13 @@ NEOML_API CLayerWrapper<CCrossEntropyLossLayer> CrossEntropyLoss(
 class NEOML_API CBinaryCrossEntropyLossLayer : public CLossLayer {
     NEOML_DNN_LAYER( CBinaryCrossEntropyLossLayer )
 public:
-    explicit CBinaryCrossEntropyLossLayer( IMathEngine& mathEngine );
+    explicit CBinaryCrossEntropyLossLayer( IMathEngine& mathEngine ) :
+        CLossLayer( mathEngine, "CCnnBinaryCrossEntropyLossLayer" ) {}
 
     // The weight for the positive side of the sigmoid
     // Values over 1 increase recall, values below 1 increase precision
-    void SetPositiveWeight( float value );
-    float GetPositiveWeight() const;
+    void SetPositiveWeight( float value ) { positiveWeightMinusOne = value - 1; }
+    float GetPositiveWeight() const { return positiveWeightMinusOne + 1; }
 
     void Serialize( CArchive& archive ) override;
 
@@ -191,9 +192,9 @@ class NEOML_API CBinaryCrossEntropyLossLayer : public CLossLayer {
 
 private:
     // constants used for calculating the function value
-    float positiveWeightMinusOneValue;
+    float positiveWeightMinusOne = 0;
 
-    void calculateStableSigmoid( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) const;
+    void calculateStableSigmoid( const CFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) const;
 };
 
 NEOML_API CLayerWrapper<CBinaryCrossEntropyLossLayer> BinaryCrossEntropyLoss(
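Note: the accessors now live in the header and the stored field defaults to 0, so a default-constructed layer reports a positive weight of 1. A minimal standalone sketch of that storage pattern (my own illustration with a hypothetical struct name, not NeoML code):

#include <cassert>

// Hypothetical stand-in for the layer's weight storage: keep (weight - 1)
// so that zero-initialization corresponds to a positive weight of 1.
struct CPositiveWeightSketch {
    void SetPositiveWeight( float value ) { positiveWeightMinusOne = value - 1; }
    float GetPositiveWeight() const { return positiveWeightMinusOne + 1; }

private:
    float positiveWeightMinusOne = 0;
};

int main()
{
    CPositiveWeightSketch w;
    assert( w.GetPositiveWeight() == 1.f ); // default weight is 1
    w.SetPositiveWeight( 2.f );             // > 1 favors recall, < 1 favors precision
    assert( w.GetPositiveWeight() == 2.f );
    return 0;
}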

NeoML/src/Dnn/Layers/BinaryCrossEntropyLayer.cpp
Lines changed: 41 additions & 51 deletions

@@ -1,4 +1,4 @@
-/* Copyright © 2017-2020 ABBYY Production LLC
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -20,22 +20,6 @@ limitations under the License.
 
 namespace NeoML {
 
-CBinaryCrossEntropyLossLayer::CBinaryCrossEntropyLossLayer( IMathEngine& mathEngine ) :
-    CLossLayer( mathEngine, "CCnnBinaryCrossEntropyLossLayer" ),
-    positiveWeightMinusOneValue( 0 )
-{
-}
-
-void CBinaryCrossEntropyLossLayer::SetPositiveWeight( float value )
-{
-    positiveWeightMinusOneValue = value - 1;
-}
-
-float CBinaryCrossEntropyLossLayer::GetPositiveWeight() const
-{
-    return positiveWeightMinusOneValue + 1;
-}
-
 void CBinaryCrossEntropyLossLayer::Reshape()
 {
     CLossLayer::Reshape();
@@ -44,8 +28,8 @@ void CBinaryCrossEntropyLossLayer::Reshape()
         "BinaryCrossEntropy layer can only work with a binary classificaion problem" );
 }
 
-void CBinaryCrossEntropyLossLayer::BatchCalculateLossAndGradient( int batchSize, CConstFloatHandle data, int /* vectorSize */,
-    CConstFloatHandle label, int /* labelSize */, CFloatHandle lossValue, CFloatHandle lossGradient )
+void CBinaryCrossEntropyLossLayer::BatchCalculateLossAndGradient( int batchSize, CConstFloatHandle data, int /*vectorSize*/,
+    CConstFloatHandle label, int /*labelSize*/, CFloatHandle lossValue, CFloatHandle lossGradient )
 {
     // Therefore the labels vector can only contain {-1, 1} values
     CFloatHandleStackVar one( MathEngine() );
@@ -56,22 +40,23 @@ void CBinaryCrossEntropyLossLayer::BatchCalculateLossAndGradient( int batchSize,
     minusOne.SetValue( -1.f );
     CFloatHandleStackVar zero( MathEngine() );
     zero.SetValue( 0.f );
-    CFloatHandleStackVar positiveWeightMinusOne( MathEngine() );
-    positiveWeightMinusOne.SetValue( positiveWeightMinusOneValue );
+    CFloatHandleStackVar positiveWeightMinusOneVar( MathEngine() );
+    positiveWeightMinusOneVar.SetValue( positiveWeightMinusOne );
 
+    CFloatHandleStackVar temp( MathEngine(), batchSize * 3 );
     // Convert the target values to [0, 1] range using the binaryLabel = 0.5 * ( label + 1 ) formula
-    CFloatHandleStackVar binaryLabel( MathEngine(), batchSize );
+    CFloatHandle binaryLabel = temp.GetHandle();
     MathEngine().VectorAddValue( label, binaryLabel, batchSize, one );
     MathEngine().VectorMultiply( binaryLabel, binaryLabel, batchSize, half );
 
     // Notations:
-    // x = logits, z = labels, q = pos_weight, l = 1 + (q - 1) * z
+    // x = logits, z = labels, q = pos_weight, lCoef = 1 + (q - 1) * z
 
     // The original loss function formula:
-    // loss = (1 - z) * x + l * log(1 + exp(-x))
+    // loss = (1 - z) * x + lCoef * log(1 + exp(-x))
 
     // The formula to avoid overflow for large exponent power in exp(-x):
-    // loss = (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
+    // loss = (1 - z) * x + lCoef * (log(1 + exp(-abs(x))) + max(-x, 0))
 
     // (1-z)*x
     CFloatHandleStackVar temp( MathEngine(), batchSize );
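A scalar sketch of the loss formula described in the comments above (an illustration of the math only, with a hypothetical helper name, not the MathEngine vector code): labels arrive as {-1, 1}, are mapped to z in {0, 1}, and log(1 + exp(-x)) is replaced by the overflow-safe log(1 + exp(-abs(x))) + max(-x, 0). The hunk also swaps the per-step CFloatHandleStackVar allocations for a single temp buffer of batchSize * 3 floats; only the first slice (binaryLabel) is visible in the shown context, and the remaining slices are presumably carved out further down.

#include <algorithm>
#include <cmath>

// Scalar illustration of the overflow-safe binary cross entropy above.
// x = logit, label in {-1, 1}, q = pos_weight.
float binaryCrossEntropy( float x, float label, float q )
{
    const float z = 0.5f * ( label + 1.f );    // binaryLabel = 0.5 * ( label + 1 )
    const float lCoef = 1.f + ( q - 1.f ) * z; // lCoef = 1 + (q - 1) * z
    // loss = (1 - z) * x + lCoef * ( log(1 + exp(-abs(x))) + max(-x, 0) )
    return ( 1.f - z ) * x
        + lCoef * ( std::log1p( std::exp( -std::fabs( x ) ) ) + std::max( -x, 0.f ) );
}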
@@ -104,11 +89,11 @@ void CBinaryCrossEntropyLossLayer::BatchCalculateLossAndGradient( int batchSize,
     MathEngine().VectorAdd( lossValue, temp, lossValue, batchSize );
 
     if( !lossGradient.IsNull() ) {
-        // loss' = (1-z) - l / ( 1+exp(x) ) = (1-z) - l * sigmoid(-x)
 
         // (z-1)
         CFloatHandleStackVar temp5( MathEngine(), batchSize );
         MathEngine().VectorAddValue( binaryLabel, temp5, batchSize, minusOne );
+        // loss' = (1 - z) - lCoef / ( 1 + exp(x) ) = (1 - z) - lCoef * sigmoid(-x)
 
         // -x
         CFloatHandleStackVar temp6( MathEngine(), batchSize );
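A matching scalar sketch of the gradient comment above, loss' = (1 - z) - lCoef * sigmoid(-x), using the overflow-safe sigmoid form that calculateStableSigmoid implements further down (again only an illustration, with hypothetical helper names):

#include <algorithm>
#include <cmath>

// Overflow-safe sigmoid: Sigmoid(x) = e^( -max(-x, 0) ) / ( 1 + e^-|x| )
static float stableSigmoid( float x )
{
    return std::exp( -std::max( -x, 0.f ) ) / ( 1.f + std::exp( -std::fabs( x ) ) );
}

// Scalar illustration of the gradient: loss' = (1 - z) - lCoef * sigmoid(-x)
float binaryCrossEntropyGrad( float x, float label, float q )
{
    const float z = 0.5f * ( label + 1.f );
    const float lCoef = 1.f + ( q - 1.f ) * z;
    return ( 1.f - z ) - lCoef * stableSigmoid( -x );
}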
@@ -130,55 +115,60 @@ void CBinaryCrossEntropyLossLayer::BatchCalculateLossAndGradient( int batchSize,
 }
 
 // Overflow-safe sigmoid calculation
-void CBinaryCrossEntropyLossLayer::calculateStableSigmoid( const CConstFloatHandle& firstHandle,
+void CBinaryCrossEntropyLossLayer::calculateStableSigmoid( const CFloatHandle& firstHandle,
     const CFloatHandle& resultHandle, int vectorSize ) const
 {
     CFloatHandleStackVar one( MathEngine() );
     one.SetValue( 1.f );
     CFloatHandleStackVar zero( MathEngine() );
     zero.SetValue( 0.f );
 
+    NeoPresume( !firstHandle.IsNull() );
+    NeoPresume( !resultHandle.IsNull() );
+    NeoPresume( firstHandle != resultHandle );
+    // reduced memory usage for calculation
+    CFloatHandle numerator = resultHandle;
+    CFloatHandle denominator = firstHandle;
+
     // The sigmoid formula:
-    // Sigmoid(x) = 1 / (1 + e^-x )
+    // Sigmoid(x) = 1 / ( 1 + e^-x )
 
     // The formula to avoid overflow for large exponent power in exp(-x):
-    // Sigmoid(x) = e^(-max(-x, 0) ) / ( 1 + e^-|x| )
+    // Sigmoid(x) = e^( -max(-x, 0) ) / ( 1 + e^-|x| )
 
-    // e^(-max(-x, 0) )
-    CFloatHandleStackVar temp( MathEngine(), vectorSize );
-    MathEngine().VectorNegMultiply( firstHandle, temp, vectorSize, one );
-    MathEngine().VectorReLU( temp, temp, vectorSize, zero );
-    MathEngine().VectorNegMultiply( temp, temp, vectorSize, one );
-    MathEngine().VectorExp( temp, temp, vectorSize );
+    // e^( -max(-x, 0) )
+    MathEngine().VectorNegMultiply( firstHandle, numerator, vectorSize, one );
+    MathEngine().VectorReLU( numerator, numerator, vectorSize, zero );
+    MathEngine().VectorNegMultiply( numerator, numerator, vectorSize, one );
+    MathEngine().VectorExp( numerator, numerator, vectorSize );
 
     // ( 1 + e^-|x| )
-    CFloatHandleStackVar temp2( MathEngine(), vectorSize );
-    MathEngine().VectorAbs( firstHandle, temp2, vectorSize );
-    MathEngine().VectorNegMultiply( temp2, temp2, vectorSize, one );
-    MathEngine().VectorExp( temp2, temp2, vectorSize );
-    MathEngine().VectorAddValue( temp2, temp2, vectorSize, one );
+    MathEngine().VectorAbs( firstHandle, denominator, vectorSize );
+    MathEngine().VectorNegMultiply( denominator, denominator, vectorSize, one );
+    MathEngine().VectorExp( denominator, denominator, vectorSize );
+    MathEngine().VectorAddValue( denominator, denominator, vectorSize, one );
 
     // The sigmoid
-    MathEngine().VectorEltwiseDivide( temp, temp2, resultHandle, vectorSize );
+    MathEngine().VectorEltwiseDivide( numerator, denominator, resultHandle, vectorSize );
 }
 
-static const int BinaryCrossEntropyLossLayerVersion = 2000;
+constexpr int binaryCrossEntropyLossLayerVersion = 2000;
 
 void CBinaryCrossEntropyLossLayer::Serialize( CArchive& archive )
 {
-    archive.SerializeVersion( BinaryCrossEntropyLossLayerVersion, CDnn::ArchiveMinSupportedVersion );
+    archive.SerializeVersion( binaryCrossEntropyLossLayerVersion, CDnn::ArchiveMinSupportedVersion );
     CLossLayer::Serialize( archive );
-
-    archive.Serialize( positiveWeightMinusOneValue );
+
+    archive.Serialize( positiveWeightMinusOne );
 }
 
-CLayerWrapper<CBinaryCrossEntropyLossLayer> BinaryCrossEntropyLoss(
-    float positiveWeight, float lossWeight )
+CLayerWrapper<CBinaryCrossEntropyLossLayer> BinaryCrossEntropyLoss( float positiveWeight, float lossWeight )
 {
-    return CLayerWrapper<CBinaryCrossEntropyLossLayer>( "BinaryCrossEntropyLoss", [=]( CBinaryCrossEntropyLossLayer* result ) {
-        result->SetPositiveWeight( positiveWeight );
-        result->SetLossWeight( lossWeight );
-    } );
+    return CLayerWrapper<CBinaryCrossEntropyLossLayer>( "BinaryCrossEntropyLoss",
+        [=]( CBinaryCrossEntropyLossLayer* result ) {
+            result->SetPositiveWeight( positiveWeight );
+            result->SetLossWeight( lossWeight );
+        } );
 }
 
 } // namespace NeoML
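The main memory saving in calculateStableSigmoid is that the two temporaries are gone: the numerator is built directly in resultHandle and the denominator in firstHandle, which is why the parameter lost its const qualifier and why the NeoPresume( firstHandle != resultHandle ) check is needed. A plain C++ sketch of the same in-place idea on std::vector buffers (an illustration of the technique with a hypothetical helper name, not the MathEngine calls):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// In-place overflow-safe sigmoid: the numerator is written into the output
// buffer and the denominator into the input buffer, so no temporaries are
// allocated. The input is clobbered, hence the aliasing check.
void stableSigmoidInPlace( std::vector<float>& x, std::vector<float>& result )
{
    assert( &x != &result && x.size() == result.size() );
    for( size_t i = 0; i < x.size(); ++i ) {
        result[i] = std::exp( -std::max( -x[i], 0.f ) ); // e^( -max(-x, 0) ) -> output buffer
        x[i] = 1.f + std::exp( -std::fabs( x[i] ) );     // 1 + e^-|x|        -> input buffer (clobbered)
        result[i] /= x[i];
    }
}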
