Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8e1a71d

Browse files
sjaz24rxwei
authored andcommitted
Fix 'conv2D' derivative (#331)
`_vjpConv2DBackpropInput` should use the filter shape for `conv2DBackpropFilter`'s `filterSizes` parameter.
1 parent b698f1f commit 8e1a71d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Sources/TensorFlow/Operators/NN.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func _vjpConv2DBackpropInput<Scalar: TensorFlowFloatingPoint>(
185185
let value = conv2DBackpropInput(x, shape: shape, filter: filter,
186186
strides: strides, padding: padding, dilations: dilations)
187187
return (value, { v in
188-
(conv2DBackpropFilter(x, input: v, filterSizes: shape, strides: strides,
188+
(conv2DBackpropFilter(x, input: v, filterSizes: filter.shapeTensor, strides: strides,
189189
padding: padding, dilations: dilations),
190190
conv2D(v, filter: filter, strides: strides, padding: padding, dilations: dilations))
191191
})

0 commit comments

Comments
 (0)