/* Copyright © 2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#include <common.h>
#pragma hdrstop

#include <TestFixture.h>
#include <DnnSimpleTest.h>

using namespace NeoML;
using namespace NeoMLTest;

namespace NeoMLTest {

struct CCrossEntropyLossLayerTestParams final {
	CCrossEntropyLossLayerTestParams( float target, float result, float lossValue ) :
		Target( target ), Result( result ), LossValue( lossValue ) {}

	float Target;
	float Result;
	float LossValue;
};

class CCrossEntropyLossLayerTest :
	public CNeoMLTestFixture, public ::testing::WithParamInterface<CCrossEntropyLossLayerTestParams> {
public:
	static bool InitTestFixture() { return true; }
	static void DeinitTestFixture() {}
};

} // namespace NeoMLTest

//---------------------------------------------------------------------------------------------------------------------
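// Checks the backward pass of softmax + cross-entropy. The reference values in rawExpectedDiff
// match ( softmax( data ) - label ) / 4 for each of the 4 objects; an object whose one-hot label
// is all zeros (or whose int label is -1) must contribute a zero diff, hence the test name.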
TEST_F( CCrossEntropyLossLayerTest, ZeroBackwardDiffTest )
{
	const auto met = MathEngine().GetType();
	if( met != MET_Cpu && met != MET_Cuda ) {
		NEOML_HILIGHT( GTEST_LOG_( INFO ) ) << "Skipped test for MathEngine type=" << met << " because it has no implementation.\n";
		return;
	}

	const float rawData[8]{
		0.7f, 0.3f,
		0.3f, 0.7f,
		0.2f, 0.8f,
		0.4f, 0.6f
	};
	const float rawExpectedDiff[8]{
		-0.10032f, 0.10032f,
		0.10032f, -0.10032f,
		0.08859f, -0.08859f,
		0.0f, 0.0f
	};
	const float rawFloatLabels[8]{
		1.f, 0.f,
		0.f, 1.f,
		0.f, 1.f,
		0.f, 0.f
	};
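	// -1 denotes the same "labelless" object as the all-zero one-hot row above;
	// its expected diff is zero (see the last row of rawExpectedDiff).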
	const int rawIntLabels[4]{ 0, 1, 1, -1 };

	CRandom random;
	CDnn dnn( random, MathEngine() );

	CPtr<CSourceLayer> data = Source( dnn, "data" );
	CPtr<CSourceLayer> label = Source( dnn, "label" );

	CPtr<CDnnSimpleTestDummyLearningLayer> learn = new CDnnSimpleTestDummyLearningLayer( MathEngine() );
	learn->SetName( "learn" );
	learn->Connect( *data );
	dnn.AddLayer( *learn );

	CPtr<CCrossEntropyLossLayer> loss = CrossEntropyLoss()( learn.Ptr(), label.Ptr() );

	CPtr<CDnnBlob> dataBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 2, 2, 2 );
	dataBlob->CopyFrom( rawData );
	data->SetBlob( dataBlob );

	CPtr<CDnnBlob> expectedDiff = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 2, 2, 2 );
	expectedDiff->CopyFrom( rawExpectedDiff );
	learn->ExpectedDiff = expectedDiff;

	CPtr<CDnnBlob> labelBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 2, 2, 2 );
	labelBlob->CopyFrom( rawFloatLabels );
	label->SetBlob( labelBlob );

	dnn.RunAndBackwardOnce();

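	// Rerun with int class-index labels; object 3's label of -1 must again produce a zero diff.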
	labelBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Int, 2, 2, 1 );
	labelBlob->CopyFrom( rawIntLabels );
	label->SetBlob( labelBlob );

	dnn.RunAndBackwardOnce();
}

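// Checks the loss gradient when softmax is disabled and the inputs are already probabilities.
// The reference diff below is consistent with loss = -log( result[target] / sum( result ) ) averaged
// over the 7 objects: since every row of resultBuff sums to 1, that gives ( 1 - target / result ) / 7
// per entry.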
TEST_F( CCrossEntropyLossLayerTest, NoSoftmax )
{
	const auto met = MathEngine().GetType();
	if( met != MET_Cpu && met != MET_Cuda ) {
		NEOML_HILIGHT( GTEST_LOG_( INFO ) ) << "Skipped test for MathEngine type=" << met << " because it has no implementation.\n";
		return;
	}

	const float resultBuff[]{ 1.f / 15, 2.f / 15, 3.f / 15, 4.f / 15, 5.f / 15,
		5.f / 15, 4.f / 15, 3.f / 15, 2.f / 15, 1.f / 15,
		0.2f, 0.2f, 0.2f, 0.2f, 0.2f,
		0.2f, 0.2f, 0.2f, 0.2f, 0.2f,
		0.2f, 0.2f, 0.2f, 0.2f, 0.2f,
		0.2f, 0.2f, 0.2f, 0.2f, 0.2f,
		0.2f, 0.2f, 0.2f, 0.2f, 0.2f
	};

	CRandom random;
	CDnn dnn( random, MathEngine() );

	CPtr<CSourceLayer> result = Source( dnn, "result" );
	CPtr<CSourceLayer> target = Source( dnn, "target" );

	CPtr<CDnnSimpleTestDummyLearningLayer> learn = new CDnnSimpleTestDummyLearningLayer( MathEngine() );
	learn->Connect( *result );
	dnn.AddLayer( *learn );

	CPtr<CCrossEntropyLossLayer> loss = CrossEntropyLoss( /*softmax*/false )( learn.Ptr(), target.Ptr() );

	CPtr<CDnnBlob> resultBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 7, 5 );
	resultBlob->CopyFrom( resultBuff );
	result->SetBlob( resultBlob );

	CPtr<CDnnBlob> floatTargetBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 7, 5 );
	floatTargetBlob->Fill( 0.f );

	floatTargetBlob->GetData().SetValueAt( 2, 1 );
	floatTargetBlob->GetObjectData( 1 ).SetValueAt( 4, 1 );
	for( int i = 2; i < 7; ++i ) {
		floatTargetBlob->GetObjectData( i ).SetValueAt( 3, 1 );
	}
	target->SetBlob( floatTargetBlob );

	CPtr<CDnnBlob> expectedDiffBlob = resultBlob->GetCopy();
	expectedDiffBlob->Fill( 1.f / 7 );

	expectedDiffBlob->GetData().SetValueAt( 2, -4.f / 7 );
	expectedDiffBlob->GetObjectData( 1 ).SetValueAt( 4, -2 );
	for( int i = 2; i < 7; ++i ) {
		expectedDiffBlob->GetObjectData( i ).SetValueAt( 3, -4.f / 7 );
	}
	learn->ExpectedDiff = expectedDiffBlob;

	dnn.RunAndBackwardOnce();

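	// Connect the loss directly to the source, round-trip the network through an archive,
	// then re-fetch the (possibly recreated) layers and restore the dummy learning layer.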
	{
		dnn.DeleteLayer( *learn );
		loss->Connect( 0, *result );
		dnn.RunAndBackwardOnce();

		CMemoryFile file;
		{
			CArchive archive( &file, CArchive::store );
			archive.Serialize( dnn );
		}
		file.SeekToBegin();
		{
			CArchive archive( &file, CArchive::load );
			archive.Serialize( dnn );
		}

		result = CheckCast<CSourceLayer>( dnn.GetLayer( result->GetName() ) );
		target = CheckCast<CSourceLayer>( dnn.GetLayer( target->GetName() ) );
		loss = CheckCast<CCrossEntropyLossLayer>( dnn.GetLayer( loss->GetName() ) );

		result->SetBlob( resultBlob );

		learn->Connect( *result );
		loss->Connect( *learn );
		dnn.AddLayer( *learn );
	}

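	// The same targets as int class indices: object 0 -> class 2, object 1 -> class 4, the rest -> class 3.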
	CPtr<CDnnBlob> intTargetBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Int, 1, 7, 1 );
	intTargetBlob->FillObject<int>( 0, 2 );
	intTargetBlob->FillObject<int>( 1, 4 );
	for( int i = 2; i < 7; ++i ) {
		intTargetBlob->GetObjectData<int>( i ).SetValue( 3 );
	}
	target->SetBlob( intTargetBlob );

	dnn.RunAndBackwardOnce();
}

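// Checks the sign convention of binary cross-entropy: the target is -1 or +1 and the result is a
// raw score. The expected loss values below match log( 1 + exp( -target * result ) ).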
TEST_P( CCrossEntropyLossLayerTest, BinaryCrossEntropyLossSignTest )
{
	const CCrossEntropyLossLayerTestParams params = GetParam();

	CRandom random;
	CDnn dnn( random, MathEngine() );

	CTextStream debugOutput;
	dnn.SetLog( &debugOutput );

	CSourceLayer* result = Source( dnn, "result" );
	CSourceLayer* target = Source( dnn, "target" );

	CPtr<CBinaryCrossEntropyLossLayer> loss = BinaryCrossEntropyLoss()( result, target );

	CPtr<CDnnBlob> resultBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 1, 1 );
	resultBlob->Fill( params.Result );
	result->SetBlob( resultBlob );

	CPtr<CDnnBlob> targetBlob = CDnnBlob::CreateDataBlob( MathEngine(), CT_Float, 1, 1, 1 );
	targetBlob->Fill( params.Target );
	target->SetBlob( targetBlob );

	dnn.RunAndBackwardOnce();

	EXPECT_TRUE( FloatEq( loss->GetLastLoss(), params.LossValue ) );
}

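// Each tuple is ( target, result, expectedLoss ).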
INSTANTIATE_TEST_CASE_P( CnnCrossEntropyLossTestInstantiation, CCrossEntropyLossLayerTest,
	::testing::Values(
		CCrossEntropyLossLayerTestParams( -1.f, -2.f, 0.126928f ),
		CCrossEntropyLossLayerTestParams( -1.f, -0.999f, 0.313531f ),
		CCrossEntropyLossLayerTestParams( -1.f, -1.f, 0.313262f ),
		CCrossEntropyLossLayerTestParams( -1.f, 1.f, 1.313262f ),
		CCrossEntropyLossLayerTestParams( 1.f, 1.f, 0.313262f ),
		CCrossEntropyLossLayerTestParams( 1.f, -1.f, 1.313262f ),
		CCrossEntropyLossLayerTestParams( 1.f, 3.1415926f, 4.2306e-2f ),
		CCrossEntropyLossLayerTestParams( -1.f, -0.451f, 0.49286f ),
		CCrossEntropyLossLayerTestParams( 1.f, 10000.f, 0.f ),
		CCrossEntropyLossLayerTestParams( -1.f, -10000.f, 0.f )
	)
);