Commit 7b26d3d

[NeoML] MultiheadAttentionPerformerLayer

Signed-off-by: Kirill Golikov <[email protected]>

1 parent 236c0e7

8 files changed: +1332 additions, −0 deletions

NeoML/include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h

Lines changed: 88 additions & 0 deletions (new file)

/* Copyright © 2023 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/Dnn/Dnn.h>

namespace NeoML {

struct CFavorAttentionDesc;

// Computes FAVOR normalized self-attention.
// https://arxiv.org/pdf/2009.14794.pdf
//
// Inputs: query, key, value
// Emulates the equation: Output ~~ softmax( query * ( key )^T / normalizer ) * value
//
//        output
//           ^
//           |
//   +---------------+
//   |  F A V O R    |  <-- projection_matrix
//   |   Attention   |
//   +---------------+
//      ^    ^    ^
//      |    |    |
//    query key value
//
class NEOML_API CFavorAttentionPerformerLayer : public CBaseLayer {
	NEOML_DNN_LAYER( CFavorAttentionPerformerLayer )
public:
	// Possible activation kernel transformations
	enum class TAKernel { SoftMax = 0, ReLU = 1 };
	// Layer inputs
	enum TInput { TI_Q = 0, TI_K = 1, TI_V = 2 };

	CFavorAttentionPerformerLayer( IMathEngine& mathEngine, const char* name = nullptr );

	// The number of random features used by the kernel approximation
	int GetRandomFeaturesCount() const { return randomFeaturesCount; }
	void SetRandomFeaturesCount( int randomFeaturesCount );

	// The activation kernel transformation
	int GetActivationKernel() const { return static_cast<int>( activation ); }
	void SetActivationKernel( int activation );

	// Whether the attention is causal (auto-regressive)
	bool GetCausal() const { return causal; }
	void SetCausal( bool causal );

	// Whether a random projection matrix is applied (should be true for the SoftMax kernel)
	bool GetProjectionMatrixType() const { return projectionMatrixType; }
	void SetProjectionMatrixType( bool projectionMatrixType );

	void Serialize( CArchive& archive ) override;

protected:
	~CFavorAttentionPerformerLayer();

	// Creates output blobs using the input blobs
	void Reshape() override;
	// One step of a forward pass
	void RunOnce() override;
	// One step of a backward pass
	void BackwardOnce() override;

private:
	int randomFeaturesCount = 1; // The number of random features to be used
	TAKernel activation = TAKernel::SoftMax; // The activation kernel
	bool causal = false; // Auto-regressive attention or not
	bool projectionMatrixType = true; // Whether a random projection matrix is applied (should be true for SoftMax)
	CFavorAttentionDesc* desc = nullptr; // The FAVOR attention description

	void destroyFavorAttentionDesc();
};

NEOML_API CLayerWrapper<CFavorAttentionPerformerLayer> FavorAttentionPerformer(
	int randomFeaturesCount, int activation, bool causal, bool projectionMatrixType );

} // namespace NeoML
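
For reference, the approximation this layer computes (FAVOR+, from the paper linked in the header): instead of materializing the full L x L softmax attention matrix, queries and keys are passed through a random-feature map phi built from randomFeaturesCount random projections, so that

	Output ~~ D^{-1} * phi(Q) * ( phi(K)^T * V ),    D = diag( phi(Q) * ( phi(K)^T * 1_L ) )

where 1_L is a vector of L ones. This lowers the cost from O(L^2 * d) to O(L * m * d) for sequence length L, head dimension d, and m = randomFeaturesCount.

A minimal usage sketch, not part of the commit: mathEngine, dnn, and the q/k/v layers producing the query, key, and value blobs are assumed to already exist; the layer name and parameter values are illustrative.

// Sketch: wiring CFavorAttentionPerformerLayer into an existing network
CPtr<CFavorAttentionPerformerLayer> favor =
	new CFavorAttentionPerformerLayer( mathEngine, "favorAttention" );
favor->SetRandomFeaturesCount( 64 ); // m random features for the kernel map
favor->SetActivationKernel( 0 );     // 0 == TAKernel::SoftMax
favor->SetCausal( false );           // bidirectional, not auto-regressive
dnn.AddLayer( *favor );
favor->Connect( CFavorAttentionPerformerLayer::TI_Q, *q );
favor->Connect( CFavorAttentionPerformerLayer::TI_K, *k );
favor->Connect( CFavorAttentionPerformerLayer::TI_V, *v );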

NeoML/include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h

Lines changed: 107 additions & 0 deletions (new file)

/* Copyright © 2023 ABBYY

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

	http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#include <NeoML/Dnn/Dnn.h>
#include <NeoML/Dnn/Layers/CompositeLayer.h>
#include <NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h>

namespace NeoML {

// Multihead Self Attention Performer
// https://arxiv.org/pdf/2009.14794.pdf
// Implementation of the multihead FAVOR-attention and FAVOR-self-attention layers.
//
// +----------------------+--------+-------------------------------------------------------
// | Parameter            | Type   | Description
// +----------------------+--------+-------------------------------------------------------
// | HiddenSize           | int    | the output dimension of the hidden layer
// | HeadCount            | int    | the number of heads that repeat the same attention structure
// | OutputSize           | int    | the size of the output
// | DropoutRate (TODO)   | float  | the dropout rate inside the attention, for training
// +----------------------+--------+-------------------------------------------------------
class NEOML_API CMultiheadAttentionPerformerLayer : public CCompositeLayer {
	NEOML_DNN_LAYER( CMultiheadAttentionPerformerLayer )
public:
	explicit CMultiheadAttentionPerformerLayer( IMathEngine& mathEngine );

	// The activation kernel transformation: SoftMax(=0) or ReLU(=1)
	// SoftMax by default
	int GetActivationKernel() const { return favor->GetActivationKernel(); }
	void SetActivationKernel( int activationKernel )
		{ favor->SetActivationKernel( activationKernel ); }

	// The number of heads in the attention
	// GetHiddenSize() must be a multiple of this value
	// By default the attention consists of 1 head
	int GetHeadCount() const { return headCount; }
	void SetHeadCount( int headCount );

	// The size of the trainable matrices
	// Must be a multiple of GetHeadCount()
	int GetHiddenSize() const { return hiddenSize; }
	void SetHiddenSize( int hiddenSize );

	// The size of the output
	int GetOutputSize() const { return outputSize; }
	void SetOutputSize( int outputSize );

	void Serialize( CArchive& archive ) override;

	// Recreates the layer if forceRebuild is true or if it does not contain sublayers yet
	void Rebuild( bool forceRebuild );

protected:
	void Reshape() override;

private:
	// The number of heads
	int headCount;
	// The size of the trainable matrices
	int hiddenSize;
	// The size of the output
	int outputSize;

	CPtr<CFavorAttentionPerformerLayer> favor;

	// Layer inputs
	enum TInputs {
		I_Q = 0,
		I_K = 1,
		I_V = 2,
		I_Mask = 3
	};

	// Layer outputs
	enum TOutputs {
		O_Output = 0,
		O_Softmax = 1
	};

	bool isCreated() const { return HasLayer( "Q" ); }
	void create();

	CBaseLayer* multiplyInputByMatrixWeights( int size, const char* name, TInputs input );
	CBaseLayer* multiplyByMatrixWeights( CBaseLayer* input, int width, const char* prefix );
	CBaseLayer* prepareQ( CBaseLayer* input );
	CBaseLayer* prepareKV( CBaseLayer* input );
	CBaseLayer* prepareOutput( CBaseLayer* input );
};

NEOML_API CLayerWrapper<CMultiheadAttentionPerformerLayer> MultiheadAttentionPerformer(
	int headCount, int hiddenSize, int outputSize );

} // namespace NeoML
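
A minimal construction sketch under the same assumptions as above (mathEngine, dnn, and an input layer exist; the layer name and sizes are illustrative). Since this is a self-attention block, the same sequence can feed all three inputs:

// Sketch: an 8-head self-attention block, hiddenSize 512 = 8 heads * 64 per head
CPtr<CMultiheadAttentionPerformerLayer> attention =
	new CMultiheadAttentionPerformerLayer( mathEngine );
attention->SetName( "selfAttention" );
attention->SetHeadCount( 8 );
attention->SetHiddenSize( 512 ); // must be a multiple of the head count
attention->SetOutputSize( 512 );
dnn.AddLayer( *attention );
attention->Connect( 0, *input ); // query
attention->Connect( 1, *input ); // key
attention->Connect( 2, *input ); // value

The declared wrapper offers the same configuration in NeoML's functional graph-building style, e.g. MultiheadAttentionPerformer( 8, 512, 512 ) applied to the input layers.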

NeoML/include/NeoML/NeoML.h

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,7 @@ limitations under the License.
 #include <NeoML/Dnn/Layers/DepthToSpaceLayer.h>
 #include <NeoML/Dnn/Layers/DotProductLayer.h>
 #include <NeoML/Dnn/Layers/EnumBinarizationLayer.h>
+#include <NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h>
 #include <NeoML/Dnn/Layers/FocalLossLayer.h>
 #include <NeoML/Dnn/Layers/FullyConnectedSourceLayer.h>
 #include <NeoML/Dnn/Layers/GlobalMaxPoolingLayer.h>
@@ -130,6 +131,7 @@ limitations under the License.
 #include <NeoML/Dnn/Layers/LrnLayer.h>
 #include <NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h>
 #include <NeoML/Dnn/Layers/ModelWrapperLayer.h>
+#include <NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h>
 #include <NeoML/Dnn/Layers/MultiHingeLossLayer.h>
 #include <NeoML/Dnn/Layers/PositionalEmbeddingLayer.h>
 #include <NeoML/Dnn/Layers/PrecisionRecallLayer.h>

NeoML/src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -117,6 +117,7 @@ set(NeoML_SOURCES
 	Dnn/Layers/DotProductLayer.cpp
 	Dnn/Layers/EnumBinarizationLayer.cpp
 	Dnn/Layers/FocalLossLayer.cpp
+	Dnn/Layers/FavorAttentionPerformerLayer.cpp
 	Dnn/Layers/FullyConnectedSourceLayer.cpp
 	Dnn/Layers/GlobalMaxPoolingLayer.cpp
 	Dnn/Layers/GlobalSumPoolingLayer.cpp
@@ -132,6 +133,7 @@ set(NeoML_SOURCES
 	Dnn/Layers/MaxOverTimePoolingLayer.cpp
 	Dnn/Layers/MobileNetV3BlockLayer.cpp
 	Dnn/Layers/ModelWrapperLayer.cpp
+	Dnn/Layers/MultiheadAttentionPerformerLayer.cpp
 	Dnn/Layers/ObjectNormalizationLayer.cpp
 	Dnn/Layers/Onnx/OnnxEltwiseLayer.cpp
 	Dnn/Layers/Onnx/OnnxCastLayer.cpp
@@ -377,6 +379,7 @@ set(NeoML_HEADERS
 	../include/NeoML/Dnn/Layers/DotProductLayer.h
 	../include/NeoML/Dnn/Layers/EnumBinarizationLayer.h
 	../include/NeoML/Dnn/Layers/FocalLossLayer.h
+	../include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h
 	../include/NeoML/Dnn/Layers/FullyConnectedSourceLayer.h
 	../include/NeoML/Dnn/Layers/GlobalMaxPoolingLayer.h
 	../include/NeoML/Dnn/Layers/GlobalSumPoolingLayer.h
@@ -392,6 +395,7 @@ set(NeoML_HEADERS
 	../include/NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h
 	../include/NeoML/Dnn/Layers/MobileNetV3BlockLayer.h
 	../include/NeoML/Dnn/Layers/ModelWrapperLayer.h
+	../include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h
 	../include/NeoML/Dnn/Layers/MultiHingeLossLayer.h
 	../include/NeoML/Dnn/Layers/ObjectNormalizationLayer.h
 	../include/NeoML/Dnn/Layers/Onnx/OnnxEltwiseLayer.h

NeoML/src/Dnn/Dnn.cpp

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,7 @@ limitations under the License.
 #include <NeoML/Dnn/Layers/DepthToSpaceLayer.h>
 #include <NeoML/Dnn/Layers/DotProductLayer.h>
 #include <NeoML/Dnn/Layers/EnumBinarizationLayer.h>
+#include <NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h>
 #include <NeoML/Dnn/Layers/FocalLossLayer.h>
 #include <NeoML/Dnn/Layers/FullyConnectedSourceLayer.h>
 #include <NeoML/Dnn/Layers/GlobalMaxPoolingLayer.h>
@@ -88,6 +89,7 @@ limitations under the License.
 #include <NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h>
 #include <NeoML/Dnn/Layers/MobileNetV3BlockLayer.h>
 #include <NeoML/Dnn/Layers/ModelWrapperLayer.h>
+#include <NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h>
 #include <NeoML/Dnn/Layers/MultiHingeLossLayer.h>
 #include <NeoML/Dnn/Layers/PositionalEmbeddingLayer.h>
 #include <NeoML/Dnn/Layers/PrecisionRecallLayer.h>
@@ -349,6 +351,7 @@ REGISTER_NEOML_LAYER( CCtcDecodingLayer, "FmlCnnCtcDecodingLayer" )
 REGISTER_NEOML_LAYER( CCtcLossLayer, "FmlCnnCtcLossLayer" )
 REGISTER_NEOML_LAYER( CDotProductLayer, "FmlCnnDotProductLayer" )
 REGISTER_NEOML_LAYER( CEnumBinarizationLayer, "FmlCnnEnumBinarizationLayer" )
+REGISTER_NEOML_LAYER( CFavorAttentionPerformerLayer, "NeoMLDnnFavorAttentionPerformerLayer" )
 REGISTER_NEOML_LAYER( CGlobalMaxPoolingLayer, "FmlCnnGlobalMaxPoolingLayer" )
 REGISTER_NEOML_LAYER( CGrnLayer, "NeoMLDnnGrnLayer" )
 REGISTER_NEOML_LAYER( CGruLayer, "FmlCnnGruLayer" )
@@ -360,6 +363,7 @@ REGISTER_NEOML_LAYER( CLoraFullyConnectedLayer, "NeoMLDnnLoraFullyConnectedLayer
 REGISTER_NEOML_LAYER( CMaxOverTimePoolingLayer, "FmlCnnMaxOverTimePoolingLayer" )
 REGISTER_NEOML_LAYER( CMobileNetV3PreSEBlockLayer, "NeoMLDnnMobileNetV3PreSEBlockLayer" )
 REGISTER_NEOML_LAYER( CMobileNetV3PostSEBlockLayer, "NeoMLDnnMobileNetV3PostSEBlockLayer" )
+REGISTER_NEOML_LAYER( CMultiheadAttentionPerformerLayer, "NeoMLDnnMultiheadAttentionPerformerLayer" )
 REGISTER_NEOML_LAYER( CMultiHingeLossLayer, "FmlCnnMultyHingeLossLayer" )
 REGISTER_NEOML_LAYER( CMultiSquaredHingeLossLayer, "FmlCnnMultySquaredHingeLossLayer" )
 REGISTER_NEOML_LAYER( CPixelToImageLayer, "FmlCnnPixelToImageLayerClass" )
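
The REGISTER_NEOML_LAYER entries bind each class to the external name written to archives ("NeoMLDnnFavorAttentionPerformerLayer", "NeoMLDnnMultiheadAttentionPerformerLayer"), which is what lets a saved network containing the new layers be restored by name. A round-trip sketch following the usual NeoML serialization pattern (the file name is illustrative; treat the exact archive flags as an assumption to check against the NeoML docs):

// Sketch: storing a network that contains the newly registered layers
CArchiveFile file( "model.cnnarch", CArchive::store );
CArchive archive( &file, CArchive::SD_Storing );
dnn.Serialize( archive ); // layers are written under their registered external names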
