Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit e465331

Browse files
Shashi456rxwei
authored andcommitted
Fix in vjpConv2DBackpropFilter (#397)
1 parent 33fe7f3 commit e465331

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

Sources/TensorFlow/Operators/NN.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,9 @@ func _vjpConv2DBackpropFilter<Scalar: TensorFlowFloatingPoint>(
189189
let value = conv2DBackpropFilter(x, input: input, filterSizes: filterSizes,
190190
strides: strides, padding: padding, dilations: dilations)
191191
return (value, { v in
192-
(conv2DBackpropInput(x, shape: filterSizes, filter: v, strides: strides,
193-
padding: padding, dilations: dilations),
194-
conv2D(input, filter: v, strides: strides, padding: padding, dilations: dilations))
192+
(conv2D(input, filter: v, strides: strides, padding: padding, dilations: dilations),
193+
conv2DBackpropInput(x, shape: x.shapeTensor, filter: v, strides: strides,
194+
padding: padding, dilations: dilations))
195195
})
196196
}
197197

0 commit comments

Comments
 (0)