Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Fix Transposed Conv2d error & add test #288

Merged
merged 20 commits into from
Nov 8, 2019
Merged

Conversation

Shashi456
Copy link
Contributor

@Shashi456 Shashi456 commented Jun 24, 2019

Fix: #282
This is a pretty long error post.
The transposed conv2d layer doesn't work. There are a number of issues that have popped up.

        let w = (input.shape[1] - (1 * paddingIndex)) *
          strides.0 + (filter.shape[0] * paddingIndex)
        let h = (input.shape[2] - (1 * paddingIndex)) *
          strides.1 + (filter.shape[1] * paddingIndex)

The first of errors i found was in the logic, Keras updates its own new dimensions by calculating

 assert padding in {'same', 'valid', 'full'}
    if dim_size is None:
        return None

    # Get the dilated kernel size
    kernel_size = (kernel_size - 1) * dilation + 1

    # Infer length if output padding is None, else compute the exact length
    if output_padding is None:
        if padding == 'valid':
            dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0)
        elif padding == 'full':
            dim_size = dim_size * stride_size - (stride_size + kernel_size - 2)
        elif padding == 'same':
            dim_size = dim_size * stride_size
    else:
        if padding == 'same':
            pad = kernel_size // 2
        elif padding == 'valid':
            pad = 0
        elif padding == 'full':
            pad = kernel_size - 1

        dim_size = ((dim_size - 1) * stride_size + kernel_size - 2 * pad +
                    output_padding)

return dim_size

So i changed the above code to reflect this.

        let w = (input.shape[1] - 1) * 
          strides.0 + (filter.shape[0] * paddingIndex)
        let h = (input.shape[2] - 1) *
          strides.1 + (filter.shape[1] * paddingIndex)

Then what i found out was, that

Running this test in tensorflow produces the output as mentioned :

from functools import reduce
import operator
import tensorflow as tf

def product(iterable):
    return reduce(operator.mul, iterable, 1)

# Returns a tensor with increasing scalar values starting from zero,
# with the given shape.
def iota(shape):
    x = tf.range(0, product(shape), dtype=tf.float32)
    return tf.reshape(x, shape)

input_shape = [1,4,2,1]
filter_shape = [4,2,1,1]
strides = (1, 1, 1, 1)
input = iota(input_shape)
filter = iota(filter_shape)
conv3d = tf.nn.conv2d_transpose(input, filter, output_shape=(1,4,2,1), strides=strides, padding='SAME')
conv3 = tf.nn.bias_add(conv3d, [8])


with tf.Session() as session:
  result = session.run(conv3)
  print(result.shape)
  # (1, 4, 2, 1)
  print(result)
#   [[[[  8.]
#    [ 12.]]

#   [[ 12.]
#    [ 28.]]

#   [[ 24.]
#    [ 64.]]

#   [[ 48.]
#    [112.]]]]

But when i ran it in swift with .same padding we get the outputs

("[[[[ 93.0],
   [ 52.0]],

  [[148.0],
   [ 76.0]],

  [[ 93.0],
   [ 46.0]],

  [[ 46.0],
   [ 22.0]]]]")

If there's something i've inherently come to understand its that, The tensorflow conv2d_transpose is just a wrapper which is called by keras. You can check the code for it here

So running the function directly might be a cheap hack. The keras transposed conv2d is the place we actually compute the new dimensions, the code for it can be found here.
and the backend for that here.

I'm very confused by the code right now. But i think there's some place i'm committing a major mistake while trying to translate the code. Would appreciate some help, sorry for the very long introductory message.

@Shashi456
Copy link
Contributor Author

We dont currently support output paddings, so the logic falls to the default value which is zero, also we dont support full padding, that would possibly need control flow.

@jekbradbury
Copy link
Contributor

This LGTM assuming the test case comes from a comparison with an existing framework (e.g. Keras).

@Shashi456
Copy link
Contributor Author

@jekbradbury, surprisingly doesn't work though

@jekbradbury
Copy link
Contributor

Ohh, as in the test doesn't pass?

@Shashi456
Copy link
Contributor Author

Yea they don't.

@jekbradbury
Copy link
Contributor

Does it give [22, 46, 46, 93, 76, 148, 52, 93] instead?

@Shashi456
Copy link
Contributor Author

Shashi456 commented Jun 25, 2019

@jekbradbury it gives

("[[[[ 93.0],
   [ 52.0]],

  [[148.0],
   [ 76.0]],

  [[ 93.0],
   [ 46.0]],

  [[ 46.0],
   [ 22.0]]]]

@jekbradbury
Copy link
Contributor

OK, so what's happening is that some frameworks define transposed convolution in different (but almost equivalent) ways, and you can switch between these definitions by reflecting the filter over both spatial axes and swapping the input and output channel dimensions. (The reason for this is that one of these definitions is the standard mathematical definition for a transposed convolution, and the other allows the framework to reuse the same kernels as the normal convolution backwards pass).

I'm not quite sure how you're getting the numbers you just pasted, but they're the correct result for the "standard mathematical definition" (although flipped) and the expected numbers in the test are the correct result for the "backwards pass of convolution" definition that's used by TF/Keras. I'm confused as to why the test currently includes Conv2D<Float>; I assume you mean TransposedConv2D<Float>.

@jekbradbury
Copy link
Contributor

I'm particularly confused because conv2dBackpropInput should give the "backwards pass of convolution" result (which is exactly what it sounds like) and it's also what Keras uses.

@Shashi456
Copy link
Contributor Author

@jekbradbury I've not seen any direct usage of the Tensorflow transposed conv2d though, It's mostly the keras api which is used which then in return ultimately calls the Tensorflow API. What I'm unsure about is, give the inputs I have, if I manually compute the output shape it's different from the output I'm getting.

Also since you mentioned it, would you suggest any changes for the current order of parameters that we use?

@Shashi456
Copy link
Contributor Author

Shashi456 commented Jun 25, 2019

the current error for the test is now :

Fatal error: Conv2DCustomBackpropInput: Size of out_backprop doesn't match computed: 
actual = 4, computed = 3 spatial_dim: 1 input: 3 filter: 4 output: 4 stride: 1 dilation: 1: 

I think the test i wrote is wrong, because the code i wrote directly calls tf.nn.conv2d_tranpose which just calls backprop2dinput , while its in the keras api where the dimensions of the output are actually calculated

@Shashi456
Copy link
Contributor Author

@jekbradbury Do you have any idea as to how we could solve the latest error?

@marcrasi
Copy link
Contributor

I haven't carefully read this whole discussion, but https://bugs.swift.org/browse/TF-540 and this thread might be related: https://groups.google.com/a/tensorflow.org/forum/#!msg/swift/UUPwV01sZrE/LszG6T7dBQAJ ?

@Shashi456
Copy link
Contributor Author

@marcrasi that was an extension error, there seems to be some error within the layer implementation as well or it's the test. Just trying to figure it out currently. That thread did help me earlier to figure out a few errors :) thanks.

@sjaz24
Copy link
Contributor

sjaz24 commented Jul 1, 2019

If you revert back to the original code, that is, put back in "- (1 * paddingIndex)" instead of always subtracting 1, then your test should work. It works for me. However, I get an error if I attempt to get gradients. Not sure if that should work or not??

    let filter = Tensor(shape: [4, 2, 1, 1], 
                        scalars: (0..<8).map(Float.init))
    let bias = Tensor<Float>([8])
    let layer = TransposedConv2D(filter: filter, 
                                 bias: bias, 
                                 activation: identity,
                                 strides: (1, 1), 
                                 padding: .same)
    let input = Tensor(shape: [1, 4, 2, 1], 
                       scalars: (0..<8).map(Float.init))
    let output = layer.inferring(from: input)
    let expected = Tensor<Float>(shape: [1, 4, 2, 1],
                                 scalars: [8, 12, 12, 28, 24, 64, 48, 112])
    print(output == expected)
    /* this outputs true. it works until this point */
   
    /* this fails */
    let (loss, grads) = layer.valueWithGradient { layer -> Tensor<Float> in 
        return layer(input).sum()
    }
    /* the following error occurs:
       Fatal error: Conv2DCustomBackpropFilter: input depth must be evenly divisible by filter depth: file /Users/danielzheng/swift-tf/tensorflow-swift-apis/Sources/TensorFlow/Bindings/EagerExecution.swift, line 299 Illegal instruction: 4
     */

@sjaz24
Copy link
Contributor

sjaz24 commented Jul 1, 2019

Also, not sure if this is an issue or not but the TransposedConv2D defines tensors as width then height whereas the underlying Raw.conv2DBackpropInput defines tensors as height then width for example from TransposedConv2D:

filter: A 4-D tensor of shape
     `[width, height, input channel count, output channel count]

But in func conv2DBackpropInput:

filter: 4-D with shape
    `[filter_height, filter_width, in_channels, out_channels]

@Shashi456
Copy link
Contributor Author

Yeah @sjaz24 I noticed that we might have documented it wrong

@Shashi456
Copy link
Contributor Author

@sjaz24 Are you sure this test passes locally for you?
Because i get this error,

Fatal error: Conv2DCustomBackpropInput: Size of out_backprop doesn't match computed: actual = 4, computed = 0 spatial_dim: 1 input: 3 filter: 4 output: 4 stride: 1 dilation: 1: file /swift-base/tensorflow-swift-apis/Sources/TensorFlow/Bindings/EagerExecution.swift, line 299
Current stack trace:
0    libswiftCore.so                    0x00007fec9e1b88d0 swift_reportError + 50
1    libswiftCore.so                    0x00007fec9e227ac0 _swift_stdlib_reportFatalErrorInFile + 115
2    libswiftCore.so                    0x00007fec9e14faee <unavailable> + 3738350
3    libswiftCore.so                    0x00007fec9e14fc67 <unavailable> + 3738727
4    libswiftCore.so                    0x00007fec9df1dc4d <unavailable> + 1436749
5    libswiftCore.so                    0x00007fec9e124a98 <unavailable> + 3562136
6    libswiftCore.so                    0x00007fec9df1d0a9 <unavailable> + 1433769
7    libswiftTensorFlow.so              0x00007fec9b45dc80 <unavailable> + 2669696
8    libswiftTensorFlow.so              0x00007fec9b2c2e00 checkOk(_:file:line:) + 461
9    libswiftTensorFlow.so              0x00007fec9b2c9f30 TFE_Op.evaluateUnsafe() + 506
10   libswiftTensorFlow.so              0x00007fec9b2ca7a0 TFE_Op.execute<A>(_:) + 132
11   libswiftTensorFlow.so              0x00007fec9b2d3434 <unavailable> + 1053748
16   libswiftTensorFlow.so              0x00007fec9b4a872b <unavailable> + 2975531
17   libswiftTensorFlow.so              0x00007fec9b59550c <unavailable> + 3945740
18   libswiftTensorFlow.so              0x00007fec9b42eba0 withContext<A>(_:_:) + 143
19   libswiftTensorFlow.so              0x00007fec9b42ed20 withLearningPhase<A>(_:_:) + 234
20   libswiftTensorFlow.so              0x00007fec9b4a85a0 Layer.inferring(from:) + 232
22   repl_swift                         0x0000000000400490 <unavailable> + 1168

@sjaz24
Copy link
Contributor

sjaz24 commented Jul 30, 2019

Yes, it works. I just basically ran the same code in Colab.

Screen Shot 2019-07-29 at 9 22 03 PM

@sjaz24
Copy link
Contributor

sjaz24 commented Jul 30, 2019

You most likely haven't undone your changes. You need to put back the original code that you changed.

(1 * paddingIndex))

@t-ae
Copy link
Contributor

t-ae commented Aug 2, 2019

@sjaz24 pointed out about filter shape order in documentation.
It looks there's another mistake, input channel count and output channel count.

Basically conv2DBackpropInput is for backpropagating Conv2D.
So what conv2DBackpropInput calls in_channels is Conv2D's input channel count, not TransposedConv2D's.
TransposedConv2D's filter shape is transposition of Conv2D's filter shape. So input channel count and output channel count must be swapped.

It's what Marc Rasi says here:
https://groups.google.com/a/tensorflow.org/forum/m/#!msg/swift/UUPwV01sZrE/LszG6T7dBQAJ

In summary, the documentation of filter should be:

filter: A 4-D tensor of shape
     `[height, width, output channel count, input channel count]

@Shashi456
Copy link
Contributor Author

Fixed this, Ready to be reviewed. Test and build pass locally.

@Shashi456 Shashi456 changed the title Fixing transposed conv2d error Fix Transposed Conv2d error & add test Aug 31, 2019
@saeta
Copy link
Contributor

saeta commented Nov 7, 2019

Hi @Shashi456! Are you able to fix the merge conflicts here? Thanks! -Brennan

@saeta saeta requested a review from marcrasi November 7, 2019 18:48
@Shashi456
Copy link
Contributor Author

@saeta done.

@marcrasi marcrasi merged commit 35dfddf into tensorflow:master Nov 8, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error in TransposedConv2d
8 participants