@@ -34,7 +34,7 @@ public func identity<Scalar>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
34
34
public extension Tensor where Scalar: TensorFlowFloatingPoint {
35
35
// TODO: Verify that these calculations are correct.
36
36
@inlinable
37
- func _vjpBatchNormalized(
37
+ internal func _vjpBatchNormalized(
38
38
alongAxis axis: Int32 ,
39
39
offset: Tensor ,
40
40
scale: Tensor ,
@@ -120,93 +120,79 @@ public extension Padding {
120
120
}
121
121
}
122
122
123
- extension Tensor where Scalar: TensorFlowFloatingPoint {
123
+ public extension Tensor where Scalar: TensorFlowFloatingPoint {
124
124
/// TensorFlow builtin conv2d gradient helper for the input.
125
125
@inlinable
126
- @differentiable (
127
- wrt: ( filter, backpropOutput) ,
128
- vjp: _vjpTFConv2DBackpropInput ( _: _: _: _: _: )
129
- )
130
- func _TFConv2DBackpropInput(
126
+ @differentiable ( wrt: ( self , filter) , vjp: _vjpConv2DBackpropInput)
127
+ internal func conv2DBackpropInput(
131
128
shape: Tensor < Int32 > ,
132
129
filter: Tensor ,
133
- backpropOutput: Tensor ,
134
130
strides: ( Int32 , Int32 , Int32 , Int32 ) ,
135
131
padding: Padding
136
132
) -> Tensor {
137
133
return Raw . conv2DBackpropInput (
138
134
inputSizes: shape,
139
135
filter: filter,
140
- outBackprop: backpropOutput ,
136
+ outBackprop: self ,
141
137
strides: [ strides. 0 , strides. 1 , strides. 2 , strides. 3 ] ,
142
138
padding: padding. raw)
143
139
}
144
140
145
- /// TensorFlow builtin conv2d gradient helper for the filter.
141
+ /// TensorFlow builtin conv2d gradient helper for the filter.
146
142
@inlinable
147
- @differentiable (
148
- wrt: ( input, backpropOutput) ,
149
- vjp: _vjpTFConv2DBackpropFilter ( _: _: _: _: _: )
150
- )
151
- func _TFConv2DBackpropFilter(
143
+ @differentiable ( wrt: ( self , input) , vjp: _vjpConv2DBackpropFilter)
144
+ internal func conv2DBackpropFilter(
152
145
input: Tensor ,
153
146
filterSizes: Tensor < Int32 > ,
154
- backpropOutput: Tensor ,
155
147
strides: ( Int32 , Int32 , Int32 , Int32 ) ,
156
148
padding: Padding
157
149
) -> Tensor {
158
150
return Raw . conv2DBackpropFilter (
159
151
input,
160
152
filterSizes: filterSizes,
161
- outBackprop: backpropOutput ,
153
+ outBackprop: self ,
162
154
strides: [ strides. 0 , strides. 1 , strides. 2 , strides. 3 ] ,
163
155
padding: padding. raw)
164
156
}
165
157
166
158
@inlinable
167
- func _vjpTFConv2DBackpropInput (
159
+ internal func _vjpConv2DBackpropInput (
168
160
_ shape: Tensor < Int32 > ,
169
161
_ filter: Tensor ,
170
- _ backpropOutput: Tensor ,
171
162
_ strides: ( Int32 , Int32 , Int32 , Int32 ) ,
172
163
_ padding: Padding
173
164
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
174
- let value = _TFConv2DBackpropInput ( shape: shape, filter: filter,
175
- backpropOutput: backpropOutput,
176
- strides: strides, padding: padding)
165
+ let value = conv2DBackpropInput ( shape: shape, filter: filter, strides: strides,
166
+ padding: padding)
177
167
return ( value, { v in
178
168
return (
179
- self . _TFConv2DBackpropFilter ( input: v, filterSizes: shape,
180
- backpropOutput: backpropOutput,
181
- strides: strides, padding: padding) ,
169
+ self . conv2DBackpropFilter ( input: v, filterSizes: shape, strides: strides,
170
+ padding: padding) ,
182
171
v. convolved2D ( withFilter: filter, strides: strides, padding: padding)
183
172
)
184
173
} )
185
174
}
186
175
187
176
@inlinable
188
- func _vjpTFConv2DBackpropFilter (
177
+ internal func _vjpConv2DBackpropFilter (
189
178
_ input: Tensor ,
190
179
_ filterSizes: Tensor < Int32 > ,
191
- _ backpropOutput: Tensor ,
192
180
_ strides: ( Int32 , Int32 , Int32 , Int32 ) ,
193
181
_ padding: Padding
194
182
) -> ( Tensor , ( Tensor ) -> ( Tensor , Tensor ) ) {
195
- let value = _TFConv2DBackpropFilter ( input: input, filterSizes: filterSizes,
196
- backpropOutput: backpropOutput,
197
- strides: strides, padding: padding)
183
+ let value = conv2DBackpropFilter ( input: input, filterSizes: filterSizes,
184
+ strides: strides, padding: padding)
198
185
return ( value, { v in
199
186
return (
200
- self . _TFConv2DBackpropInput ( shape: filterSizes, filter: v,
201
- backpropOutput: backpropOutput,
202
- strides: strides, padding: padding) ,
187
+ self . conv2DBackpropInput ( shape: filterSizes, filter: v, strides: strides,
188
+ padding: padding) ,
203
189
input. convolved2D ( withFilter: v, strides: strides, padding: padding)
204
190
)
205
191
} )
206
192
}
207
193
208
194
@inlinable
209
- func _vjpConvolved2D(
195
+ internal func _vjpConvolved2D(
210
196
filter: Tensor ,
211
197
strides: ( Int32 , Int32 , Int32 , Int32 ) ,
212
198
padding: Padding
@@ -215,20 +201,20 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
215
201
padding: padding)
216
202
return ( value, { v in
217
203
return (
218
- self . _TFConv2DBackpropInput (
219
- shape: self . shapeTensor, filter: filter, backpropOutput : v ,
204
+ v . conv2DBackpropInput (
205
+ shape: self . shapeTensor, filter: filter,
220
206
strides: strides, padding: padding
221
207
) ,
222
- self . _TFConv2DBackpropFilter (
223
- input: self , filterSizes: filter. shapeTensor, backpropOutput : v ,
208
+ v . conv2DBackpropFilter (
209
+ input: self , filterSizes: filter. shapeTensor,
224
210
strides: strides, padding: padding
225
211
)
226
212
)
227
213
} )
228
214
}
229
215
230
216
@inlinable
231
- func _vjpMaxPooled(
217
+ internal func _vjpMaxPooled(
232
218
kernelSize: ( Int32 , Int32 , Int32 , Int32 ) ,
233
219
strides: ( Int32 , Int32 , Int32 , Int32 ) ,
234
220
padding: Padding
@@ -250,7 +236,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
250
236
}
251
237
252
238
@inlinable
253
- func _vjpAveragePooled(
239
+ internal func _vjpAveragePooled(
254
240
kernelSize: ( Int32 , Int32 , Int32 , Int32 ) ,
255
241
strides: ( Int32 , Int32 , Int32 , Int32 ) ,
256
242
padding: Padding
@@ -284,8 +270,8 @@ public extension Tensor where Scalar: FloatingPoint {
284
270
/// - Precondition: `filter` must have rank 4.
285
271
@inlinable @inline ( __always)
286
272
@differentiable (
287
- wrt: ( self , filter) , vjp: _vjpConvolved2D ( filter : strides : padding : )
288
- where Scalar : TensorFlowFloatingPoint
273
+ wrt: ( self , filter) , vjp: _vjpConvolved2D
274
+ where Scalar: TensorFlowFloatingPoint
289
275
)
290
276
func convolved2D(
291
277
withFilter filter: Tensor ,
0 commit comments