Skip to content

Commit 76d935f

Browse files
committed
[NeoML] CCenterLossLayer mem-optimize
Signed-off-by: Kirill Golikov <[email protected]>
1 parent b35445c commit 76d935f

File tree

2 files changed

+46
-50
lines changed

2 files changed

+46
-50
lines changed

NeoML/include/NeoML/Dnn/Layers/CenterLossLayer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -67,7 +67,7 @@ class NEOML_API CCenterLossLayer : public CLossLayer {
6767
// The internal blobs
6868
CPtr<CDnnBlob> classCentersBlob;
6969

70-
void updateCenters(const CFloatHandle& tempDiffHandle);
70+
void updateCenters( const CConstFloatHandle& tempDiff );
7171
};
7272

7373
NEOML_API CLayerWrapper<CCenterLossLayer> CenterLoss( int numberOfClasses,

NeoML/src/Dnn/Layers/CenterLossLayer.cpp

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -31,11 +31,11 @@ CCenterLossLayer::CCenterLossLayer( IMathEngine& mathEngine ) :
3131
oneMult->GetData().SetValue( 1.f );
3232
}
3333

34-
static const int CenterLossLayerVersion = 2000;
34+
constexpr int centerLossLayerVersion = 2000;
3535

3636
void CCenterLossLayer::Serialize( CArchive& archive )
3737
{
38-
archive.SerializeVersion( CenterLossLayerVersion, CDnn::ArchiveMinSupportedVersion );
38+
archive.SerializeVersion( centerLossLayerVersion, CDnn::ArchiveMinSupportedVersion );
3939
CLossLayer::Serialize( archive );
4040

4141
archive.Serialize( numberOfClasses );
@@ -59,86 +59,82 @@ void CCenterLossLayer::Reshape()
5959
}
6060

6161
void CCenterLossLayer::BatchCalculateLossAndGradient( int batchSize, CConstFloatHandle data, int vectorSize,
62-
CConstIntHandle label, int /* labelSize */, CFloatHandle lossValue, CFloatHandle lossGradient )
62+
CConstIntHandle label, int /*labelSize*/, CFloatHandle lossValue, CFloatHandle lossGradient )
6363
{
6464
// The total input size
6565
const int inputDataSize = batchSize * vectorSize;
6666

67-
if(classCentersBlob == 0) {
67+
if(classCentersBlob == nullptr) {
6868
classCentersBlob = CDnnBlob::CreateMatrix(MathEngine(), CT_Float, numberOfClasses, vectorSize);
69-
classCentersBlob->Fill<float>( 0.f );
69+
classCentersBlob->Fill( 0.f );
7070
}
7171
// The current class centers
72-
CConstFloatHandle classCenters = classCentersBlob->GetData<float>();
72+
CConstFloatHandle classCenters = classCentersBlob->GetData();
7373
// Remember the difference between the input features and the current class centers
7474
// for these objects according to their labels: x_i - c_{y_i}
75-
CFloatHandleVar tempDiffHandle(MathEngine(), inputDataSize);
75+
CFloatHandleStackVar tempDiff( MathEngine(), inputDataSize );
7676

7777
// Copy the current center values for the input classes
78-
CLookupDimension lookupDimension;
79-
lookupDimension.VectorCount = numberOfClasses;
80-
lookupDimension.VectorSize = vectorSize;
78+
CLookupDimension lookupDimension( numberOfClasses, vectorSize );
8179
MathEngine().VectorMultichannelLookupAndCopy( batchSize, 1, label, &classCenters, &lookupDimension, 1,
82-
tempDiffHandle.GetHandle(), vectorSize );
83-
80+
tempDiff, vectorSize );
8481
// Remember the difference between the calculated features and the current centers for these objects
85-
MathEngine().VectorSub( data, tempDiffHandle.GetHandle(), tempDiffHandle.GetHandle(), inputDataSize );
86-
87-
// Calculate the squared difference from above and the error on the elements
88-
CFloatHandleVar diffSquared(MathEngine(), inputDataSize);
89-
MathEngine().VectorEltwiseMultiply( tempDiffHandle.GetHandle(), tempDiffHandle.GetHandle(), diffSquared.GetHandle(), inputDataSize );
90-
MathEngine().SumMatrixColumns( lossValue, diffSquared.GetHandle(), batchSize, vectorSize );
82+
MathEngine().VectorSub( data, tempDiff, tempDiff, inputDataSize );
9183

9284
// When not learning, that is, running the network to get the current loss value,
9385
// there is no need to calculate loss gradient and update the centers
94-
if( lossGradient.IsNull() ) {
95-
return;
86+
if( !lossGradient.IsNull() ) {
87+
// The x_i - c_{y_i} value is the same as derivative by the inputs
88+
MathEngine().VectorCopy( lossGradient, tempDiff, inputDataSize );
89+
// Update the class centers
90+
updateCenters( tempDiff );
9691
}
97-
// The x_i - c_{y_i} value is the same as derivative by the inputs
98-
MathEngine().VectorCopy( lossGradient, tempDiffHandle.GetHandle(), tempDiffHandle.Size() );
9992

100-
// Update the class centers
101-
updateCenters( tempDiffHandle.GetHandle());
93+
CFloatHandle tempDiffSquared = tempDiff;
94+
// Calculate the squared difference from above and the error on the elements
95+
MathEngine().VectorEltwiseMultiply( tempDiff, tempDiff, tempDiffSquared, inputDataSize );
96+
MathEngine().SumMatrixColumns( lossValue, tempDiffSquared, batchSize, vectorSize );
10297
}
10398

10499
// Update the class centers on the backward pass using the current batch data
105-
void CCenterLossLayer::updateCenters(const CFloatHandle& tempDiffHandle)
100+
void CCenterLossLayer::updateCenters( const CConstFloatHandle& tempDiff )
106101
{
102+
const int inputSize = inputBlobs[0]->GetDataSize();
107103
const int objectCount = inputBlobs[0]->GetObjectCount();
108104
const int numberOfFeatures = inputBlobs[0]->GetObjectSize();
105+
const int classCentersSize = classCentersBlob->GetDataSize();
109106

110-
CFloatHandle classCenters = classCentersBlob->GetData<float>();
107+
CFloatHandle classCenters = classCentersBlob->GetData();
111108
CConstIntHandle labels = inputBlobs[1]->GetData<int>();
112109

113-
CLookupDimension lookupDimension;
114-
lookupDimension.VectorCount = numberOfClasses;
115-
lookupDimension.VectorSize = numberOfFeatures;
110+
CFloatHandleStackVar temp( MathEngine(), classCentersSize * 2 + inputSize );
111+
CFloatHandle classCentersNumerator = temp;
112+
CFloatHandle classCentersDenominator = temp + classCentersSize;
113+
CFloatHandle onesTempBlob = temp + classCentersSize * 2;
116114
CFloatHandle handlesArray[1];
115+
117116
// The numerator of the correction: the total of x_i - c_{y_i}, aggregated by classes
118-
CFloatHandleVar classCentersUpdatesNumerator(MathEngine(), classCentersBlob->GetDataSize());
119-
MathEngine().VectorFill(classCentersUpdatesNumerator.GetHandle(), 0.0f, classCentersUpdatesNumerator.Size());
120-
handlesArray[0] = classCentersUpdatesNumerator.GetHandle();
117+
MathEngine().VectorFill(classCentersNumerator, 0.0f, classCentersSize);
118+
handlesArray[0] = classCentersNumerator;
121119

122-
MathEngine().VectorMultichannelLookupAndAddToTable( objectCount, 1, labels,
123-
handlesArray, &lookupDimension, 1, oneMult->GetData(), tempDiffHandle, numberOfFeatures );
120+
CLookupDimension lookupDimension( /*count*/numberOfClasses, /*size*/numberOfFeatures );
121+
MathEngine().VectorMultichannelLookupAndAddToTable( objectCount, 1, labels,
122+
handlesArray, &lookupDimension, 1, oneMult->GetData(), tempDiff, numberOfFeatures );
124123

125-
CFloatHandleVar onesTemporaryBlob(MathEngine(), inputBlobs[0]->GetDataSize());
126-
MathEngine().VectorFill(onesTemporaryBlob.GetHandle(), 1.0f, onesTemporaryBlob.Size());
124+
MathEngine().VectorFill( onesTempBlob, 1.0f, inputSize );
127125
// The denominator of the correction: 1 + the number of elements of this class in the batch
128-
CFloatHandleVar classCentersUpdatesDenominator(MathEngine(), classCentersBlob->GetDataSize());
129-
MathEngine().VectorFill(classCentersUpdatesDenominator.GetHandle(), 1.0f, classCentersUpdatesDenominator.Size());
130-
handlesArray[0] = classCentersUpdatesDenominator.GetHandle();
126+
MathEngine().VectorFill(classCentersDenominator, 1.0f, classCentersSize);
127+
handlesArray[0] = classCentersDenominator;
131128

132-
MathEngine().VectorMultichannelLookupAndAddToTable( objectCount, 1, labels,
133-
handlesArray, &lookupDimension, 1, oneMult->GetData(), onesTemporaryBlob.GetHandle(), numberOfFeatures );
129+
MathEngine().VectorMultichannelLookupAndAddToTable( objectCount, 1, labels,
130+
handlesArray, &lookupDimension, 1, oneMult->GetData(), onesTempBlob, numberOfFeatures );
134131

135132
// The final correction = \alpha * numerator / denominator
136-
MathEngine().VectorEltwiseDivide( classCentersUpdatesNumerator.GetHandle(), classCentersUpdatesDenominator.GetHandle(),
137-
classCentersUpdatesNumerator.GetHandle(), classCentersBlob->GetDataSize() );
138-
MathEngine().VectorMultiply( classCentersUpdatesNumerator.GetHandle(), classCentersUpdatesNumerator.GetHandle(),
139-
classCentersBlob->GetDataSize(), classCentersConvergenceRate->GetData() );
140-
MathEngine().VectorAdd( classCenters, classCentersUpdatesNumerator.GetHandle(), classCenters,
141-
classCentersBlob->GetDataSize() );
133+
MathEngine().VectorEltwiseDivide( classCentersNumerator, classCentersDenominator,
134+
classCentersNumerator, classCentersSize );
135+
MathEngine().VectorMultiply( classCentersNumerator, classCentersNumerator,
136+
classCentersSize, classCentersConvergenceRate->GetData() );
137+
MathEngine().VectorAdd( classCenters, classCentersNumerator, classCenters, classCentersSize );
142138
}
143139

144140
CLayerWrapper<CCenterLossLayer> CenterLoss(

0 commit comments

Comments
 (0)