[NeoML] Layers mem-optimize #1118

Open
wants to merge 15 commits into base: master
74 changes: 35 additions & 39 deletions NeoML/include/NeoML/Dnn/Layers/BatchNormalizationLayer.h
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -25,6 +25,14 @@ namespace NeoML {
class NEOML_API CBatchNormalizationLayer : public CBaseLayer {
NEOML_DNN_LAYER( CBatchNormalizationLayer )
public:
// The training parameters names
enum TParamName {
PN_Gamma,
PN_Beta,

PN_Count
};

explicit CBatchNormalizationLayer( IMathEngine& mathEngine );

void Serialize( CArchive& archive ) override;
@@ -36,21 +44,21 @@ class NEOML_API CBatchNormalizationLayer : public CBaseLayer {
// If true, "channel-based" statistics is gathered, that is, the data for each channel is averaged across the other dimensions
// If false, the statistics is averaged across the batch (BatchLength * BatchWidth * ListSize dimensions)
bool IsChannelBased() const { return isChannelBased; }
void SetChannelBased(bool _isChannelBased);
void SetChannelBased( bool _isChannelBased );

// Convergence rate for slow statistics (gathered across several batches)
// This value may be from (0; 1] interval
// The smaller this value, the more the statistics take previous data into account (~ 1 / rate)
float GetSlowConvergenceRate() const { return slowConvergenceRate->GetData().GetValue(); }
void SetSlowConvergenceRate(float rate);
void SetSlowConvergenceRate( float rate );

// The final normalization parameters
CPtr<CDnnBlob> GetFinalParams() { updateFinalParams(); return finalParams == 0 ? 0 : finalParams->GetCopy(); }
void SetFinalParams(const CPtr<CDnnBlob>& _params);
void SetFinalParams( const CPtr<CDnnBlob>& _params );

// Indicates if the free term should be set to zero ("no bias")
bool IsZeroFreeTerm() const { return isZeroFreeTerm; }
void SetZeroFreeTerm(bool _isZeroFreeTerm) { isZeroFreeTerm = _isZeroFreeTerm; }
void SetZeroFreeTerm( bool _isZeroFreeTerm ) { isZeroFreeTerm = _isZeroFreeTerm; }

// Indicates if the final params weights should be used for initialization
// After initialization the value will be reset to false automatically
@@ -67,29 +75,6 @@ class NEOML_API CBatchNormalizationLayer : public CBaseLayer {
int BlobsForLearn() const override { return 0; }

private:
bool isChannelBased;
bool isZeroFreeTerm; // indicates if the free term is zero
CPtr<CDnnBlob> slowConvergenceRate; // the convergence rate for slow statistics
CPtr<CDnnBlob> finalParams; // the final linear operation parameters (gamma, beta)

// The variables used to calculate statistics
CPtr<CDnnBlob> varianceEpsilon;
CPtr<CDnnBlob> fullBatchInv;
CPtr<CDnnBlob> varianceNorm;
CPtr<CDnnBlob> residual;

CPtr<CDnnBlob> normalized;

CPtr<CDnnBlob> varianceMult;

// The training parameters names
enum TParamName {
PN_Gamma = 0, // gamma
PN_Beta, // beta

PN_Count,
};

// Internal (untrainable) parameters
enum TInternalParamName {
IPN_Average = 0, // the average across the batch
@@ -98,28 +83,39 @@ class NEOML_API CBatchNormalizationLayer : public CBaseLayer {
IPN_SlowAverage, // the average across several batches
IPN_SlowVariance, // the variance estimate across several batches

IPN_Count,
IPN_Count
};

CPtr<CDnnBlob> finalParams; // the final linear operation parameters (gamma, beta)
CPtr<CDnnBlob> internalParams;
CPtr<CDnnBlob> normalized;

CPtr<CDnnBlob> slowConvergenceRate; // the convergence rate for slow statistics
// The variables used to calculate statistics
CPtr<CDnnBlob> varianceEpsilon;
CPtr<CDnnBlob> fullBatchInv;
CPtr<CDnnBlob> varianceNorm;
CPtr<CDnnBlob> residual;
CPtr<CDnnBlob> varianceMult;

bool useFinalParamsForInitialization; // indicates if final params should be used for initialization
bool isChannelBased = true;
bool isZeroFreeTerm = false; // indicates if the free term is zero
bool isFinalParamDirty = false; // indicates if final params need updating
bool useFinalParamsForInitialization = false; // indicates if final params should be used for initialization

bool checkAndCreateParams();
void getFullBatchAndObjectSize(int& fullBatchSize, int& objectSize);
bool checkAndCreateParams( const CFloatHandle& temp );
void getFullBatchAndObjectSize( int& fullBatchSize, int& objectSize );
void runWhenLearning();
void runWhenNoLearning();
void processInput(const CPtr<CDnnBlob>& inputBlob, const CPtr<CDnnBlob>& paramBlob);
void processInput( const CPtr<CDnnBlob>& inputBlob, const CPtr<CDnnBlob>& paramBlob );
void calculateAverage();
void calculateVariance();
void calculateVariance( const CFloatHandle& temp );
void calculateNormalized();
void updateSlowParams(bool isInit);
void updateSlowParams( bool isInit );
void backwardWhenLearning();
void backwardWhenNoLearning();

bool isFinalParamDirty; // indicates if final params need updating
void updateFinalParams();

void initializeFromFinalParams();
void initializeFromFinalParams( const CFloatHandle& ones );
};

NEOML_API CLayerWrapper<CBatchNormalizationLayer> BatchNormalization(
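
Based on the public setters declared in this header, constructing and tuning the layer might look like the sketch below. The mathEngine, dnn, and conv objects are placeholders assumed for illustration and are not part of this diff.

CPtr<CBatchNormalizationLayer> bn = new CBatchNormalizationLayer( mathEngine );
bn->SetName( "bn" );
bn->SetChannelBased( true );         // gather per-channel statistics
bn->SetSlowConvergenceRate( 0.01f ); // slow statistics remember roughly the last 1 / 0.01 = 100 batches
bn->SetZeroFreeTerm( false );        // keep the beta (free) term
dnn.AddLayer( *bn );
bn->Connect( *conv );                // assumes a preceding layer named "conv"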
6 changes: 3 additions & 3 deletions NeoML/include/NeoML/Dnn/Layers/CenterLossLayer.h
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -59,15 +59,15 @@ class NEOML_API CCenterLossLayer : public CLossLayer {

private:
// The number of classes
int numberOfClasses;
int numberOfClasses = 0;
// The centers convergence rate
CPtr<CDnnBlob> classCentersConvergenceRate;
// The unit multiplier
CPtr<CDnnBlob> oneMult;
// The internal blobs
CPtr<CDnnBlob> classCentersBlob;

void updateCenters(const CFloatHandle& tempDiffHandle);
void updateCenters( const CConstFloatHandle& tempDiff, const CConstIntHandle& labels, int batchSize, int vectorSize );
};

NEOML_API CLayerWrapper<CCenterLossLayer> CenterLoss( int numberOfClasses,
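
For reference, classCentersConvergenceRate plays the role of the rate α in the standard center-loss formulation (Wen et al., 2016), which this layer presumably follows; the equations below restate that formulation and are not taken from this header:

\[
L_c = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2, \qquad
c_j \leftarrow c_j - \alpha\,\Delta c_j, \qquad
\Delta c_j = \frac{\sum_{i=1}^{m} \mathbf{1}[y_i = j]\,( c_j - x_i )}{1 + \sum_{i=1}^{m} \mathbf{1}[y_i = j]}
\]

where x_i is the i-th feature vector, y_i its class label, and c_j the running center of class j, presumably held in classCentersBlob.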
12 changes: 5 additions & 7 deletions NeoML/include/NeoML/Dnn/Layers/FocalLossLayer.h
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -45,28 +45,26 @@ class NEOML_API CFocalLossLayer : public CLossLayer {

protected:
void Reshape() override;
virtual void BatchCalculateLossAndGradient( int batchSize, CConstFloatHandle data, int vectorSize, CConstFloatHandle label,
void BatchCalculateLossAndGradient( int batchSize, CConstFloatHandle data, int vectorSize, CConstFloatHandle label,
int labelSize, CFloatHandle lossValue, CFloatHandle lossGradient ) override;

private:
// The gamma parameter from the paper
// Specifies the degree to which learning will concentrate on difficult-to-distinguish objects
CPtr<CDnnBlob> focalForce;

// -1
CPtr<CDnnBlob> minusOne;

// The handle for acceptable minimum and maximum probability values (so that separation can be performed correctly)
CPtr<CDnnBlob> minProbValue;
CPtr<CDnnBlob> maxProbValue;

// Calculates the function gradient
void calculateGradient( CFloatHandle correctClassProbabilityPerBatchHandle, int batchSize, int labelSize,
CFloatHandle remainderVectorHandle, CFloatHandle entropyPerBatchHandle, CFloatHandle tempMatrixHandle,
void calculateGradient( CFloatHandle correctClassProbabilityPerBatch, int batchSize, int labelSize,
CFloatHandle remainderVector, CConstFloatHandle entropyPerBatch, CConstFloatHandle remainderPowered,
CConstFloatHandle label, CFloatHandle lossGradient );
};

NEOML_API CLayerWrapper<CFocalLossLayer> FocalLoss(
float focalForce = CFocalLossLayer::DefaultFocalForceValue, float lossWeight = 1.0f );
float focalForce = CFocalLossLayer::DefaultFocalForceValue, float lossWeight = 1.0f );

} // namespace NeoML
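
The focalForce blob above holds the γ exponent from the focal loss paper (Lin et al., 2017). Assuming the layer follows that formulation, the per-object loss is

\[
\mathrm{FL}(p_t) = -\,(1 - p_t)^{\gamma} \log p_t
\]

where p_t is the predicted probability of the correct class, presumably clamped to the range given by minProbValue and maxProbValue declared above, and γ = focalForce (the paper's default is γ = 2).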
46 changes: 24 additions & 22 deletions NeoML/include/NeoML/Dnn/Layers/LossLayer.h
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -52,15 +52,16 @@ class NEOML_API CLossLayer : public CBaseLayer {
// is averaged across the BatchSize dimension and returned as the result
float Test( int batchSize, CConstFloatHandle data, int vectorSize, CConstFloatHandle label, int labelSize,
CConstFloatHandle dataDelta );
float Test(int batchSize, CConstFloatHandle data, int vectorSize, CConstIntHandle label, int labelSize,
float Test( int batchSize, CConstFloatHandle data, int vectorSize, CConstIntHandle label, int labelSize,
CConstFloatHandle dataDelta );

// Tests the layer performance on the basis of data, labels, and dataDelta generated by a uniform random distribution
// labels and data are of the same size: batchSize * vectorSize
float TestRandom(CRandom& random, int batchSize, float dataLabelMin, float dataLabelMax, float deltaAbsMax, int vectorSize);
float TestRandom( CRandom& random, int batchSize, float dataLabelMin, float dataLabelMax, float deltaAbsMax,
int vectorSize, bool oneHot = false );
// Similar to the previous method, but with labels generated as Int [0; labelMax), with size 1
float TestRandom(CRandom& random, int batchSize, float dataMin, float dataMax, int labelMax, float deltaAbsMax,
int vectorSize);
float TestRandom( CRandom& random, int batchSize, float dataMin, float dataMax, int labelMax, float deltaAbsMax,
int vectorSize );

protected:
const CPtr<CDnnBlob>& GetWeights() { return weights; }
@@ -102,11 +103,12 @@ class NEOML_API CLossLayer : public CBaseLayer {
CObjectArray<CDnnBlob> lossGradientBlobs;

template<class T>
float testImpl(int batchSize, CConstFloatHandle data, int vectorSize, CTypedMemoryHandle<const T> label, int labelSize,
CConstFloatHandle dataDelta);
float testImpl(int batchSize, CConstFloatHandle data, int vectorSize, CTypedMemoryHandle<const T> label,
int labelSize, CConstFloatHandle dataDelta);
};

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CCrossEntropyLossLayer implements a layer that calculates the loss value as cross-entropy between the result and the standard
// By default, softmax function is additionally applied to the results
/*
@@ -140,7 +142,7 @@ class NEOML_API CLossLayer : public CBaseLayer {
class NEOML_API CCrossEntropyLossLayer : public CLossLayer {
NEOML_DNN_LAYER( CCrossEntropyLossLayer )
public:
explicit CCrossEntropyLossLayer( IMathEngine& mathEngine );
explicit CCrossEntropyLossLayer( IMathEngine& mathEngine ) : CLossLayer( mathEngine, "CCnnCrossEntropyLossLayer" ) {}

// Indicates if softmax function should be applied to input data. True by default.
// If you turn off the flag, make sure each vector you pass to the input contains only positive numbers making 1 in total.
@@ -158,26 +160,27 @@ class NEOML_API CCrossEntropyLossLayer : public CLossLayer {
int labelSize, CFloatHandle lossValue, CFloatHandle lossGradient) override;

private:
bool isSoftmaxApplied;
bool isSoftmaxApplied = true;
};

NEOML_API CLayerWrapper<CCrossEntropyLossLayer> CrossEntropyLoss(
bool isSoftmaxApplied = true, float lossWeight = 1.0f );

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CBinaryCrossEntropyLossLayer is a binary variant of cross-entropy
// It takes non-normalized probabilities of +1 class of size BatchSize x 1 as the first input (network response)
// and blob of the same size with values -1.f / +1.f as the second input (labels)
class NEOML_API CBinaryCrossEntropyLossLayer : public CLossLayer {
NEOML_DNN_LAYER( CBinaryCrossEntropyLossLayer )
public:
explicit CBinaryCrossEntropyLossLayer( IMathEngine& mathEngine );
explicit CBinaryCrossEntropyLossLayer( IMathEngine& mathEngine ) :
CLossLayer( mathEngine, "CCnnBinaryCrossEntropyLossLayer" ) {}

// The weight for the positive side of the sigmoid
// Values over 1 increase recall, values below 1 increase precision
void SetPositiveWeight( float value );
float GetPositiveWeight() const;
void SetPositiveWeight( float value ) { positiveWeightMinusOne = value - 1; }
float GetPositiveWeight() const { return positiveWeightMinusOne + 1; }

void Serialize( CArchive& archive ) override;

@@ -189,15 +192,15 @@ class NEOML_API CBinaryCrossEntropyLossLayer : public CLossLayer {

private:
// constants used for calculating the function value
float positiveWeightMinusOneValue;
float positiveWeightMinusOne = 0;

void calculateStableSigmoid( const CConstFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) const;
void calculateStableSigmoid( const CFloatHandle& firstHandle, const CFloatHandle& resultHandle, int vectorSize ) const;
};

NEOML_API CLayerWrapper<CBinaryCrossEntropyLossLayer> BinaryCrossEntropyLoss(
float positiveWeight = 1.0f, float lossWeight = 1.0f );

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CEuclideanLossLayer implements a layer that calculates the loss function
// equal to the sum of squared differences between the result and the standard
@@ -218,7 +221,7 @@ class NEOML_API CEuclideanLossLayer : public CLossLayer {

NEOML_API CLayerWrapper<CEuclideanLossLayer> EuclideanLoss( float lossWeight = 1.0f );

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CL1LossLayer implements a layer that estimates the loss value as abs(result - standard)
// The layer has two inputs: #0 - result, #1 - standard
Expand All @@ -237,7 +240,7 @@ class NEOML_API CL1LossLayer : public CLossLayer {

NEOML_API CLayerWrapper<CL1LossLayer> L1Loss( float lossWeight = 1.0f );

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CHingeLossLayer implements a layer that estimates the loss value as max(0, 1 - result * standard)
// The layer has two inputs: #0 - result, #1 - standard
@@ -257,7 +260,7 @@ class NEOML_API CHingeLossLayer : public CLossLayer {

NEOML_API CLayerWrapper<CHingeLossLayer> HingeLoss( float lossWeight = 1.0f );

///////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------

// CSquaredHingeLossLayer implements a layer that estimates the loss value as max(0, 1 - result * standard)**2
// The layer has two inputs: #0 - result, #1 - standard
@@ -276,7 +279,6 @@ class NEOML_API CSquaredHingeLossLayer : public CLossLayer {
int labelSize, CFloatHandle lossValue, CFloatHandle lossGradient) override;
};

NEOML_API CLayerWrapper<CSquaredHingeLossLayer> SquaredHingeLoss(
float lossWeight = 1.0f );
NEOML_API CLayerWrapper<CSquaredHingeLossLayer> SquaredHingeLoss( float lossWeight = 1.0f );

} // namespace NeoML
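
A minimal sketch of wiring the binary cross-entropy loss declared above, using the positive-class weighting it exposes; mathEngine, dnn, classifier, and labels are placeholder objects assumed for illustration:

CPtr<CBinaryCrossEntropyLossLayer> loss = new CBinaryCrossEntropyLossLayer( mathEngine );
loss->SetName( "loss" );
loss->SetPositiveWeight( 2.f );  // stored internally as positiveWeightMinusOne = 1.f; values over 1 favor recall
dnn.AddLayer( *loss );
loss->Connect( 0, *classifier ); // input #0: non-normalized probabilities of the +1 class
loss->Connect( 1, *labels );     // input #1: labels with values -1.f / +1.f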
20 changes: 10 additions & 10 deletions NeoML/include/NeoML/Dnn/Layers/PrecisionRecallLayer.h
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -33,23 +33,23 @@ class NEOML_API CPrecisionRecallLayer : public CQualityControlLayer {

// Retrieves the result over the last batch as a 4-number array:
// true positives, positives total, true negatives, negatives total
void GetLastResult( CArray<int>& results );
void GetLastResult( CArray<int>& results ) const;

protected:
void Reshape() override;
void OnReset() override;
void RunOnceAfterReset() override;

virtual int& PositivesTotal(){ return positivesTotal; };
virtual int& NegativesTotal(){ return negativesTotal; };
virtual int& PositivesCorrect(){ return positivesCorrect; };
virtual int& NegativesCorrect(){ return negativesCorrect; };
virtual int PositivesTotal() const { return accumulated->GetData<int>().GetValueAt( TP_PositivesTotal ); }
virtual int NegativesTotal() const { return accumulated->GetData<int>().GetValueAt( TP_NegativesTotal ); }
virtual int PositivesCorrect() const { return accumulated->GetData<int>().GetValueAt( TP_PositivesCorrect ); }
virtual int NegativesCorrect() const { return accumulated->GetData<int>().GetValueAt( TP_NegativesCorrect ); }

private:
int positivesTotal;
int negativesTotal;
int positivesCorrect;
int negativesCorrect;
enum { TP_PositivesCorrect, TP_PositivesTotal, TP_NegativesCorrect, TP_NegativesTotal, TP_Count };
// Store consts in device memory to avoid excess syncs
CPtr<CDnnBlob> accumulated;
CPtr<CDnnBlob> current;
};

NEOML_API CLayerWrapper<CPrecisionRecallLayer> PrecisionRecall();
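
GetLastResult above fills a four-element array: true positives, positives total, true negatives, negatives total. Assuming the two totals count the ground-truth positive and negative objects (an interpretation, not spelled out in this header), precision and recall for the positive class can be derived as in this sketch; precisionRecall is a placeholder CPtr<CPrecisionRecallLayer>:

CArray<int> r;
precisionRecall->GetLastResult( r );
const int tp = r[0], posTotal = r[1], tn = r[2], negTotal = r[3];
const int fp = negTotal - tn; // negatives misclassified as positive
const float recall = posTotal > 0 ? float( tp ) / posTotal : 0.f;
const float precision = tp + fp > 0 ? float( tp ) / ( tp + fp ) : 0.f;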