
Support multiple inputs in CDropoutLayer #991


Draft · wants to merge 4 commits into master
16 changes: 7 additions & 9 deletions NeoML/Python/neoml/Dnn/Dropout.py
@@ -45,27 +45,25 @@ class Dropout(Layer):

.. rubric:: Layer inputs:

(1) a data blob of any dimensions.
The layer can have any number of inputs; each accepts a data blob of any dimensions.

.. rubric:: Layer outputs:

(1) a blob of the same dimensions, with some of the elements set to 0,
during training only.
When you run the network, this layer does nothing.
The layer returns one output for each input. The output blob size matches the size of the corresponding input blob.
"""

def __init__(self, input_layer, rate=0.5, spatial=False, batchwise=False, name=None):
def __init__(self, input_layers, rate=0.5, spatial=False, batchwise=False, name=None):

if type(input_layer) is PythonWrapper.Dropout:
super().__init__(input_layer)
if type(input_layers) is PythonWrapper.Dropout:
super().__init__(input_layers)
return

layers, outputs = check_input_layers(input_layer, 1)
layers, outputs = check_input_layers(input_layers, 0)

if rate < 0 or rate >= 1:
raise ValueError('The `rate` must be in [0, 1).')

internal = PythonWrapper.Dropout(str(name), layers[0], int(outputs[0]), float(rate), bool(spatial), bool(batchwise))
internal = PythonWrapper.Dropout(str(name), layers, outputs, float(rate), bool(spatial), bool(batchwise))
super().__init__(internal)

@property
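For reference, here is how the updated Python API is meant to be used. A minimal sketch, assuming the `neoml` package built from this branch; the layer names are illustrative:

```python
import neoml

math_engine = neoml.MathEngine.CpuMathEngine()
dnn = neoml.Dnn.Dnn(math_engine)

# Two independent input streams feeding one shared dropout layer
data0 = neoml.Dnn.Source(dnn, "data0")
data1 = neoml.Dnn.Source(dnn, "data1")

# Dropout now accepts a list of input layers; rate must lie in [0, 1),
# otherwise the constructor raises ValueError
dropout = neoml.Dnn.Dropout([data0, data1], rate=0.3, name="dropout")

# Output i corresponds to input i; a (layer, index) pair selects output i
sink0 = neoml.Dnn.Sink(dropout, "sink0")
sink1 = neoml.Dnn.Sink((dropout, 1), "sink1")
```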
10 changes: 6 additions & 4 deletions NeoML/Python/src/PyDropoutLayer.cpp
@@ -46,20 +46,22 @@ void InitializeDropoutLayer( py::module& m )
{
return new CPyDropoutLayer( *layer.Layer<CDropoutLayer>(), layer.MathEngineOwner() );
}))
.def( py::init([]( const std::string& name, const CPyLayer& layer, int outputNumber, float dropoutRate,
.def( py::init([]( const std::string& name, const py::list& layers, const py::list& outputs, float dropoutRate,
bool isSpatial, bool isBatchwise )
{
py::gil_scoped_release release;
CDnn& dnn = layer.Dnn();
CDnn& dnn = layers[0].cast<CPyLayer>().Dnn();
IMathEngine& mathEngine = dnn.GetMathEngine();
CPtr<CDropoutLayer> dropout = new CDropoutLayer( mathEngine );
dropout->SetDropoutRate( dropoutRate );
dropout->SetSpatial( isSpatial );
dropout->SetBatchwise( isBatchwise );
dropout->SetName( FindFreeLayerName( dnn, "Dropout", name ).c_str() );
dnn.AddLayer( *dropout );
dropout->Connect( 0, layer.BaseLayer(), outputNumber );
return new CPyDropoutLayer( *dropout, layer.MathEngineOwner() );
for( int i = 0; i < layers.size(); ++i ) {
dropout->Connect( i, layers[i].cast<CPyLayer>().BaseLayer(), outputs[i].cast<int>() );
}
return new CPyDropoutLayer( *dropout, layers[0].cast<CPyLayer>().MathEngineOwner() );
}) )
.def( "get_rate", &CPyDropoutLayer::GetDropoutRate, py::return_value_policy::reference )
.def( "set_rate", &CPyDropoutLayer::SetDropoutRate, py::return_value_policy::reference )
19 changes: 12 additions & 7 deletions NeoML/Python/tests.py
@@ -628,18 +628,23 @@ def test_dotproduct(self):
def test_dropout(self):
math_engine = neoml.MathEngine.CpuMathEngine()
dnn = neoml.Dnn.Dnn(math_engine)
source = neoml.Dnn.Source(dnn, "source")
dropout = neoml.Dnn.Dropout(source, 0.5, True, True, "dropout")
sink = neoml.Dnn.Sink(dropout, "sink")
source0 = neoml.Dnn.Source(dnn, "source0")
source1 = neoml.Dnn.Source(dnn, "source1")
dropout = neoml.Dnn.Dropout([source0, source1], 0.5, True, True, "dropout")
sink0 = neoml.Dnn.Sink(dropout, "sink0")
sink1 = neoml.Dnn.Sink((dropout, 1), "sink1")
layer = dnn.layers['dropout']
self.assertEqual(layer.name, 'dropout')

input = neoml.Blob.asblob(math_engine, np.ones((2, 3, 5, 4), dtype=np.float32), (2, 3, 1, 5, 1, 1, 4))
inputs = {"source": input}
input0 = neoml.Blob.asblob(math_engine, np.ones((2, 3, 5, 4), dtype=np.float32), (2, 3, 1, 5, 1, 1, 4))
input1 = neoml.Blob.asblob(math_engine, np.ones((7, 2, 4, 3), dtype=np.float32), (1, 7, 2, 1, 4, 3, 1))
inputs = {'source0': input0, 'source1': input1}
outputs = dnn.run(inputs)
a = outputs["sink"].asarray()
a = outputs["sink0"].asarray()
b = outputs["sink1"].asarray()

self.assertEqual(a.shape, input.asarray().shape)
self.assertEqual(a.shape, input0.asarray().shape)
self.assertEqual(b.shape, input1.asarray().shape)
self.assertEqual(dropout.rate, 0.5)
self.assertEqual(dropout.spatial, True)
self.assertEqual(dropout.batchwise, True)
4 changes: 2 additions & 2 deletions NeoML/docs/en/API/NN/DropoutLayer.md
@@ -55,8 +55,8 @@ The layer has no trainable parameters.

## Inputs

The single input accepts a data blob of arbitrary size.
Each input accepts a data blob of arbitrary size.

## Outputs

The single output returns a blob of the same size with some of the elements set to `0`. Note that this will happen only during training; when you are running the network without training no elements are dropped out.
Each output returns a blob of the same size as the corresponding input, with some of the elements set to `0`. Note that this will happen only during training; when you are running the network without training, no elements are dropped out.
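Because dropout only fires during training, the documented inference behavior can be checked by running the network without learning and verifying that the output equals the input. A minimal sketch using the same helpers as the tests in this PR:

```python
import neoml
import numpy as np

math_engine = neoml.MathEngine.CpuMathEngine()
dnn = neoml.Dnn.Dnn(math_engine)
source = neoml.Dnn.Source(dnn, "source")
dropout = neoml.Dnn.Dropout(source, 0.9, name="dropout")
sink = neoml.Dnn.Sink(dropout, "sink")

blob = neoml.Blob.asblob(math_engine, np.ones((2, 3, 4), dtype=np.float32),
                         (2, 3, 1, 1, 1, 1, 4))
out = dnn.run({"source": blob})["sink"].asarray()

# dnn.run performs inference only, so dropout is a no-op:
# no elements are zeroed even with a 0.9 rate
assert np.array_equal(out, blob.asarray())
```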
4 changes: 2 additions & 2 deletions NeoML/docs/ru/API/NN/DropoutLayer.md
@@ -51,8 +51,8 @@ void SetBatchwise( bool value );

## Inputs

The single input accepts a data blob of arbitrary size.
Each input accepts a data blob of arbitrary size.

## Outputs

The single output contains a blob of the same size. During training, some of the elements are set to zero.
Each output contains a blob of the same size as the corresponding input. During training, some of the elements are set to zero.
8 changes: 3 additions & 5 deletions NeoML/include/NeoML/Dnn/Layers/DropoutLayer.h
@@ -44,22 +44,20 @@ class NEOML_API CDropoutLayer : public CBaseInPlaceLayer {
void SetBatchwise( bool value );

protected:
~CDropoutLayer() override { destroyDropoutDesc(); }

// CBaseLayer methods
void RunOnce() override;
void BackwardOnce() override;
void OnReshaped() override;
int BlobsForBackward() const override { return 0; }

private:
CDropoutDesc* desc; // the dropout description
CPointerArray<CDropoutDesc> descs; // the dropout descriptions
float dropoutRate; // the dropout rate
bool isSpatial; // the spatial mode (channel-wise)
bool isBatchwise; // the batchwise mode

void initDropoutDesc();
void destroyDropoutDesc();
void initDropoutDescs();
void destroyDropoutDescs();
};

NEOML_API CLayerWrapper<CDropoutLayer> Dropout( float dropoutRate,
53 changes: 29 additions & 24 deletions NeoML/src/Dnn/Layers/DropoutLayer.cpp
@@ -22,7 +22,6 @@ namespace NeoML {

CDropoutLayer::CDropoutLayer( IMathEngine& mathEngine ) :
CBaseInPlaceLayer( mathEngine, "CCnnDropoutLayer" ),
desc( 0 ),
dropoutRate( 0 ),
isSpatial( false ),
isBatchwise( false )
@@ -41,7 +40,7 @@
archive.Serialize( isBatchwise );

if( archive.IsLoading() ) {
destroyDropoutDesc();
destroyDropoutDescs();
}
}

@@ -51,7 +50,7 @@
if( dropoutRate != value ) {
dropoutRate = value;
if( GetDnn() != 0 ) {
destroyDropoutDesc();
destroyDropoutDescs();
}
}
}
@@ -61,7 +60,7 @@
if( value != isSpatial ) {
isSpatial = value;
if( GetDnn() != 0 ) {
destroyDropoutDesc();
destroyDropoutDescs();
}
}
}
@@ -71,57 +70,63 @@
if( value != isBatchwise ) {
isBatchwise = value;
if( GetDnn() != 0 ) {
destroyDropoutDesc();
destroyDropoutDescs();
}
}
}

void CDropoutLayer::OnReshaped()
{
destroyDropoutDesc();
destroyDropoutDescs();
}

void CDropoutLayer::RunOnce()
{
if( !IsBackwardPerformed() ) {
MathEngine().VectorCopy( outputBlobs[0]->GetData(), inputBlobs[0]->GetData(),
inputBlobs[0]->GetDataSize() );
for( int i = 0; i < inputBlobs.Size(); ++i ) {
MathEngine().VectorCopy( outputBlobs[i]->GetData(), inputBlobs[i]->GetData(),
inputBlobs[i]->GetDataSize() );
}
return;
}

initDropoutDesc();
initDropoutDescs();

MathEngine().Dropout( *desc, inputBlobs[0]->GetData(), outputBlobs[0]->GetData() );
for( int i = 0; i < inputBlobs.Size(); ++i ) {
MathEngine().Dropout( *descs[i], inputBlobs[i]->GetData(), outputBlobs[i]->GetData() );
}
}

void CDropoutLayer::BackwardOnce()
{
// Backward pass is only possible when learning
NeoAssert( desc != 0 );
for( int i = 0; i < outputDiffBlobs.Size(); ++i ) {
// Backward pass is only possible when learning
NeoAssert( descs[i] != 0 );

MathEngine().Dropout( *desc, outputDiffBlobs[0]->GetData(), inputDiffBlobs[0]->GetData() );
MathEngine().Dropout( *descs[i], outputDiffBlobs[i]->GetData(), inputDiffBlobs[i]->GetData() );
}

if( !GetDnn()->IsRecurrentMode() || GetDnn()->IsFirstSequencePos() ) {
// Clear the memory after the whole sequence is processed
destroyDropoutDesc();
destroyDropoutDescs();
}
}

void CDropoutLayer::initDropoutDesc()
void CDropoutLayer::initDropoutDescs()
{
if( desc == 0 ) {
desc = MathEngine().InitDropout( dropoutRate, isSpatial, isBatchwise, inputBlobs[0]->GetDesc(), outputBlobs[0]->GetDesc(),
GetDnn()->Random().Next() );
descs.SetSize( inputBlobs.Size() );
for( int i = 0; i < descs.Size(); ++i ) {
if( descs[i] == nullptr ) {
descs.ReplaceAt( MathEngine().InitDropout( dropoutRate, isSpatial, isBatchwise, inputBlobs[i]->GetDesc(),
outputBlobs[i]->GetDesc(), GetDnn()->Random().Next() ), i );
}
}
}

void CDropoutLayer::destroyDropoutDesc()
void CDropoutLayer::destroyDropoutDescs()
{
if( desc != 0 ) {
delete desc;
desc = 0;
for( int i = 0; i < descs.Size(); ++i ) {
descs.ReplaceAt( nullptr, i );
}
}

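The per-input descriptors above only come into play when the network is learning. A rough sketch of a training step that exercises this path; `FullyConnected`, `EuclideanLoss`, and `dnn.learn` are assumptions about the surrounding NeoML Python API, since none of them appear in this diff:

```python
import neoml
import numpy as np

math_engine = neoml.MathEngine.CpuMathEngine()
dnn = neoml.Dnn.Dnn(math_engine)
data = neoml.Dnn.Source(dnn, "data")
labels = neoml.Dnn.Source(dnn, "labels")
dropout = neoml.Dnn.Dropout(data, 0.5, name="dropout")
fc = neoml.Dnn.FullyConnected(dropout, 1, name="fc")  # assumed helper
loss = neoml.Dnn.EuclideanLoss((fc, labels), name="loss")  # assumed helper

x = neoml.Blob.asblob(math_engine, np.random.rand(8, 4).astype(np.float32),
                      (1, 8, 1, 1, 1, 1, 4))
y = neoml.Blob.asblob(math_engine, np.zeros((8, 1), dtype=np.float32),
                      (1, 8, 1, 1, 1, 1, 1))

# During learn() a dropout mask is generated per input (initDropoutDescs)
# and released once the sequence finishes (destroyDropoutDescs)
dnn.learn({"data": x, "labels": y})
```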