1
- /* Copyright © 2017-2020 ABBYY Production LLC
1
+ /* Copyright © 2017-2024 ABBYY
2
2
3
3
Licensed under the Apache License, Version 2.0 (the "License");
4
4
you may not use this file except in compliance with the License.
@@ -31,11 +31,11 @@ CCenterLossLayer::CCenterLossLayer( IMathEngine& mathEngine ) :
31
31
oneMult->GetData ().SetValue ( 1 .f );
32
32
}
33
33
34
// Serialization version of this layer: it is the value passed as the current version to
// archive.SerializeVersion() in CCenterLossLayer::Serialize().
// The change below only renames the constant and turns it into a constexpr; the value (2000) is the same.
- static const int CenterLossLayerVersion = 2000 ;
34
+ constexpr int centerLossLayerVersion = 2000 ;
35
35
36
36
void CCenterLossLayer::Serialize ( CArchive& archive )
37
37
{
38
- archive.SerializeVersion ( CenterLossLayerVersion , CDnn::ArchiveMinSupportedVersion );
38
+ archive.SerializeVersion ( centerLossLayerVersion , CDnn::ArchiveMinSupportedVersion );
39
39
CLossLayer::Serialize ( archive );
40
40
41
41
archive.Serialize ( numberOfClasses );
@@ -59,86 +59,82 @@ void CCenterLossLayer::Reshape()
59
59
}
60
60
61
61
// Calculates the center loss for a batch and, when requested, its gradient.
//
// NOTE(review): this text is a diff rendering - lines starting with '-' belong to the previous
// implementation, lines starting with '+' to the new one, unmarked code lines are shared by both,
// and the bare numbers are old/new line numbers of the diff view.
//
// Parameters (as used by the code below):
//   batchSize, vectorSize - the input is processed as batchSize * vectorSize floats (inputDataSize)
//   data         - the input features; the looked-up class centers are subtracted from it
//   label        - integer labels used to look up a row of the class centers table for each object
//   lossValue    - receives the result of SumMatrixColumns over the squared differences
//                  (presumably one value per object - confirm against the IMathEngine contract)
//   lossGradient - if not null, receives the x_i - c_{y_i} differences; a null handle means
//                  "loss only": no gradient is written and the class centers are not updated
//
// Both versions write lossValue in every case. The new version differs in that it squares the
// differences in place instead of allocating a second buffer, so the squaring has to be done last.
void CCenterLossLayer::BatchCalculateLossAndGradient ( int batchSize, CConstFloatHandle data, int vectorSize,
62
- CConstIntHandle label, int /* labelSize */ , CFloatHandle lossValue, CFloatHandle lossGradient )
62
+ CConstIntHandle label, int /* labelSize*/ , CFloatHandle lossValue, CFloatHandle lossGradient )
63
63
{
64
64
// The total input size
65
65
const int inputDataSize = batchSize * vectorSize;
66
66
67
// The class centers table (numberOfClasses x vectorSize) is created on first use and starts at zero
- if (classCentersBlob == 0 ) {
67
+ if (classCentersBlob == nullptr ) {
68
68
classCentersBlob = CDnnBlob::CreateMatrix (MathEngine (), CT_Float, numberOfClasses, vectorSize);
69
- classCentersBlob->Fill < float > ( 0 .f );
69
+ classCentersBlob->Fill ( 0 .f );
70
70
}
71
71
// The current class centers
72
- CConstFloatHandle classCenters = classCentersBlob->GetData < float > ();
72
+ CConstFloatHandle classCenters = classCentersBlob->GetData ();
73
73
// Remember the difference between the input features and the current class centers
74
74
// for these objects according to their labels: x_i - c_{y_i}
75
- CFloatHandleVar tempDiffHandle ( MathEngine (), inputDataSize);
75
+ CFloatHandleStackVar tempDiff ( MathEngine (), inputDataSize );
76
76
77
77
// Copy the current center values for the input classes
78
- CLookupDimension lookupDimension;
79
- lookupDimension.VectorCount = numberOfClasses;
80
- lookupDimension.VectorSize = vectorSize;
78
+ CLookupDimension lookupDimension ( numberOfClasses, vectorSize );
81
79
MathEngine ().VectorMultichannelLookupAndCopy ( batchSize, 1 , label, &classCenters, &lookupDimension, 1 ,
82
- tempDiffHandle.GetHandle (), vectorSize );
83
-
80
+ tempDiff, vectorSize );
84
81
// Remember the difference between the calculated features and the current centers for these objects
85
// (the subtraction overwrites the looked-up centers in the temporary buffer: buffer = data - buffer)
- MathEngine ().VectorSub ( data, tempDiffHandle.GetHandle (), tempDiffHandle.GetHandle (), inputDataSize );
86
-
87
- // Calculate the squared difference from above and the error on the elements
88
// Previous version only: the squares go into a separate buffer and the loss is written
// before the early return below
- CFloatHandleVar diffSquared (MathEngine (), inputDataSize);
89
- MathEngine ().VectorEltwiseMultiply ( tempDiffHandle.GetHandle (), tempDiffHandle.GetHandle (), diffSquared.GetHandle (), inputDataSize );
90
- MathEngine ().SumMatrixColumns ( lossValue, diffSquared.GetHandle (), batchSize, vectorSize );
82
+ MathEngine ().VectorSub ( data, tempDiff, tempDiff, inputDataSize );
91
83
92
84
// When not learning, that is, running the network to get the current loss value,
93
85
// there is no need to calculate loss gradient and update the centers
94
- if ( lossGradient.IsNull () ) {
95
- return ;
86
+ if ( !lossGradient.IsNull () ) {
87
+ // The x_i - c_{y_i} value is the same as derivative by the inputs
88
+ MathEngine ().VectorCopy ( lossGradient, tempDiff, inputDataSize );
89
+ // Update the class centers
90
+ updateCenters ( tempDiff );
96
91
}
97
- // The x_i - c_{y_i} value is the same as derivative by the inputs
98
- MathEngine ().VectorCopy ( lossGradient, tempDiffHandle.GetHandle (), tempDiffHandle.Size () );
99
92
100
- // Update the class centers
101
- updateCenters ( tempDiffHandle.GetHandle ());
93
// New version only: tempDiffSquared is initialized from tempDiff, so the squares presumably
// overwrite the differences in the same buffer. This is why it is done only after the
// gradient copy and the centers update above, which both need the unsquared differences.
+ CFloatHandle tempDiffSquared = tempDiff;
94
+ // Calculate the squared difference from above and the error on the elements
95
+ MathEngine ().VectorEltwiseMultiply ( tempDiff, tempDiff, tempDiffSquared, inputDataSize );
96
+ MathEngine ().SumMatrixColumns ( lossValue, tempDiffSquared, batchSize, vectorSize );
102
97
}
103
98
104
99
// Update the class centers on the backward pass using the current batch data
//
// NOTE(review): this text is a diff rendering - lines starting with '-' belong to the previous
// implementation, lines starting with '+' to the new one, unmarked code lines are shared by both,
// and the bare numbers are old/new line numbers of the diff view.
//
// The argument holds the x_i - c_{y_i} differences computed by the caller for the current batch
// (the new version takes it as a const handle because it is only read here).
// The labels are read from inputBlobs[1], the object count and size from inputBlobs[0].
// Steps, identical in both versions:
//   numerator   = per-class sum of the differences (table starts at 0)
//   denominator = 1 + per-class sum of ones (table starts at 1)
//   classCenters += classCentersConvergenceRate * numerator / denominator
// The new version takes the three temporary arrays from one allocation instead of three.
105
- void CCenterLossLayer::updateCenters (const CFloatHandle& tempDiffHandle )
100
+ void CCenterLossLayer::updateCenters ( const CConstFloatHandle& tempDiff )
106
101
{
102
+ const int inputSize = inputBlobs[0 ]->GetDataSize ();
107
103
const int objectCount = inputBlobs[0 ]->GetObjectCount ();
108
104
const int numberOfFeatures = inputBlobs[0 ]->GetObjectSize ();
105
+ const int classCentersSize = classCentersBlob->GetDataSize ();
109
106
110
- CFloatHandle classCenters = classCentersBlob->GetData < float > ();
107
+ CFloatHandle classCenters = classCentersBlob->GetData ();
111
108
CConstIntHandle labels = inputBlobs[1 ]->GetData <int >();
112
109
113
- CLookupDimension lookupDimension;
114
- lookupDimension.VectorCount = numberOfClasses;
115
- lookupDimension.VectorSize = numberOfFeatures;
110
// New version only: a single temporary buffer laid out as
// [ numerator: classCentersSize | denominator: classCentersSize | ones: inputSize ]
+ CFloatHandleStackVar temp ( MathEngine (), classCentersSize * 2 + inputSize );
111
+ CFloatHandle classCentersNumerator = temp;
112
+ CFloatHandle classCentersDenominator = temp + classCentersSize;
113
+ CFloatHandle onesTempBlob = temp + classCentersSize * 2 ;
116
114
// One-element array of table handles: the lookup-and-add calls below take the tables as an array
CFloatHandle handlesArray[1 ];
115
+
117
116
// The numerator of the correction: the total of x_i - c_{y_i}, aggregated by classes
118
- CFloatHandleVar classCentersUpdatesNumerator (MathEngine (), classCentersBlob->GetDataSize ());
119
- MathEngine ().VectorFill (classCentersUpdatesNumerator.GetHandle (), 0 .0f , classCentersUpdatesNumerator.Size ());
120
- handlesArray[0 ] = classCentersUpdatesNumerator.GetHandle ();
117
+ MathEngine ().VectorFill (classCentersNumerator, 0 .0f , classCentersSize);
118
+ handlesArray[0 ] = classCentersNumerator;
121
119
122
- MathEngine ().VectorMultichannelLookupAndAddToTable ( objectCount, 1 , labels,
123
- handlesArray, &lookupDimension, 1 , oneMult->GetData (), tempDiffHandle, numberOfFeatures );
120
+ CLookupDimension lookupDimension ( /* count*/ numberOfClasses, /* size*/ numberOfFeatures );
121
+ MathEngine ().VectorMultichannelLookupAndAddToTable ( objectCount, 1 , labels,
122
+ handlesArray, &lookupDimension, 1 , oneMult->GetData (), tempDiff, numberOfFeatures );
124
123
125
// A buffer of ones with the same size as the input data; it is added to the denominator table
// by label below, in the same way as the differences were added to the numerator table
- CFloatHandleVar onesTemporaryBlob (MathEngine (), inputBlobs[0 ]->GetDataSize ());
126
- MathEngine ().VectorFill (onesTemporaryBlob.GetHandle (), 1 .0f , onesTemporaryBlob.Size ());
124
+ MathEngine ().VectorFill ( onesTempBlob, 1 .0f , inputSize );
127
125
// The denominator of the correction: 1 + the number of elements of this class in the batch
128
- CFloatHandleVar classCentersUpdatesDenominator (MathEngine (), classCentersBlob->GetDataSize ());
129
- MathEngine ().VectorFill (classCentersUpdatesDenominator.GetHandle (), 1 .0f , classCentersUpdatesDenominator.Size ());
130
- handlesArray[0 ] = classCentersUpdatesDenominator.GetHandle ();
126
+ MathEngine ().VectorFill (classCentersDenominator, 1 .0f , classCentersSize);
127
+ handlesArray[0 ] = classCentersDenominator;
131
128
132
- MathEngine ().VectorMultichannelLookupAndAddToTable ( objectCount, 1 , labels,
133
- handlesArray, &lookupDimension, 1 , oneMult->GetData (), onesTemporaryBlob. GetHandle () , numberOfFeatures );
129
+ MathEngine ().VectorMultichannelLookupAndAddToTable ( objectCount, 1 , labels,
130
+ handlesArray, &lookupDimension, 1 , oneMult->GetData (), onesTempBlob , numberOfFeatures );
134
131
135
132
// The final correction = \alpha * numerator / denominator
// (\alpha is the value stored in classCentersConvergenceRate; the numerator buffer is reused
// for the intermediate results and then added to the class centers in place)
136
- MathEngine ().VectorEltwiseDivide ( classCentersUpdatesNumerator.GetHandle (), classCentersUpdatesDenominator.GetHandle (),
137
- classCentersUpdatesNumerator.GetHandle (), classCentersBlob->GetDataSize () );
138
- MathEngine ().VectorMultiply ( classCentersUpdatesNumerator.GetHandle (), classCentersUpdatesNumerator.GetHandle (),
139
- classCentersBlob->GetDataSize (), classCentersConvergenceRate->GetData () );
140
- MathEngine ().VectorAdd ( classCenters, classCentersUpdatesNumerator.GetHandle (), classCenters,
141
- classCentersBlob->GetDataSize () );
133
+ MathEngine ().VectorEltwiseDivide ( classCentersNumerator, classCentersDenominator,
134
+ classCentersNumerator, classCentersSize );
135
+ MathEngine ().VectorMultiply ( classCentersNumerator, classCentersNumerator,
136
+ classCentersSize, classCentersConvergenceRate->GetData () );
137
+ MathEngine ().VectorAdd ( classCenters, classCentersNumerator, classCenters, classCentersSize );
142
138
}
143
139
144
140
CLayerWrapper<CCenterLossLayer> CenterLoss (
0 commit comments