@@ -291,6 +291,101 @@ func _vjpConv3D<Scalar: TensorFlowFloatingPoint>(
291
291
} )
292
292
}
293
293
294
+ /// TensorFlow builtin depthwiseConv2D gradient helper for the input.
295
+ @differentiable ( wrt: ( x, filter) , vjp: _vjpdepthwiseConv2dBackpropInput)
296
+ @usableFromInline
297
+ func depthwiseConv2dBackpropInput< Scalar: TensorFlowFloatingPoint > (
298
+ _ x: Tensor < Scalar > ,
299
+ shape: Tensor < Int32 > ,
300
+ filter: Tensor < Scalar > ,
301
+ strides: ( Int , Int , Int , Int ) ,
302
+ padding: Padding
303
+ ) -> Tensor < Scalar > {
304
+ return Raw . depthwiseConv2dNativeBackpropInput (
305
+ inputSizes: shape,
306
+ filter: filter,
307
+ outBackprop: x,
308
+ strides: [ Int32 ( strides. 0 ) , Int32 ( strides. 1 ) , Int32 ( strides. 2 ) , Int32 ( strides. 3 ) ] ,
309
+ padding: padding. raw)
310
+ }
311
+
312
+ /// TensorFlow builtin depthwiseConv2D gradient helper for the filter.
313
+ @differentiable ( wrt: ( x, input) , vjp: _vjpdepthwiseConv2dBackpropFilter)
314
+ @usableFromInline
315
+ func depthwiseConv2dBackpropFilter< Scalar: TensorFlowFloatingPoint > (
316
+ _ x: Tensor < Scalar > ,
317
+ input: Tensor < Scalar > ,
318
+ filterSizes: Tensor < Int32 > ,
319
+ strides: ( Int , Int , Int , Int ) ,
320
+ padding: Padding
321
+ ) -> Tensor < Scalar > {
322
+ return Raw . depthwiseConv2dNativeBackpropFilter (
323
+ x,
324
+ filterSizes: filterSizes,
325
+ outBackprop: x,
326
+ strides: [ Int32 ( strides. 0 ) , Int32 ( strides. 1 ) , Int32 ( strides. 2 ) , Int32 ( strides. 3 ) ] ,
327
+ padding: padding. raw)
328
+ }
329
+
330
+ @usableFromInline
331
+ func _vjpdepthwiseConv2dBackpropInput< Scalar: TensorFlowFloatingPoint > (
332
+ _ x: Tensor < Scalar > ,
333
+ _ shape: Tensor < Int32 > ,
334
+ _ filter: Tensor < Scalar > ,
335
+ _ strides: ( Int , Int , Int , Int ) ,
336
+ _ padding: Padding
337
+ ) -> ( Tensor < Scalar > , ( Tensor < Scalar > ) -> ( Tensor < Scalar > , Tensor < Scalar > ) ) {
338
+ let value = depthwiseConv2dBackpropInput ( x, shape: shape, filter: filter, strides: strides,
339
+ padding: padding)
340
+ return ( value, { v in
341
+ return (
342
+ depthwiseConv2dBackpropFilter ( x, input: v, filterSizes: shape, strides: strides,
343
+ padding: padding) ,
344
+ depthwiseConv2D ( v, filter: filter, strides: strides, padding: padding)
345
+ )
346
+ } )
347
+ }
348
+
349
+ @usableFromInline
350
+ func _vjpdepthwiseConv2dBackpropFilter< Scalar: TensorFlowFloatingPoint > (
351
+ _ x: Tensor < Scalar > ,
352
+ _ input: Tensor < Scalar > ,
353
+ _ filterSizes: Tensor < Int32 > ,
354
+ _ strides: ( Int , Int , Int , Int ) ,
355
+ _ padding: Padding
356
+ ) -> ( Tensor < Scalar > , ( Tensor < Scalar > ) -> ( Tensor < Scalar > , Tensor < Scalar > ) ) {
357
+ let value = depthwiseConv2dBackpropFilter ( x, input: input, filterSizes: filterSizes,
358
+ strides: strides, padding: padding)
359
+ return ( value, { v in
360
+ return (
361
+ depthwiseConv2dBackpropInput ( x, shape: filterSizes, filter: v, strides: strides,
362
+ padding: padding) ,
363
+ depthwiseConv2D ( input, filter: v, strides: strides, padding: padding)
364
+ )
365
+ } )
366
+ }
367
+
368
+ @usableFromInline
369
+ func _vjpDepthwiseConv2D< Scalar: TensorFlowFloatingPoint > (
370
+ _ x: Tensor < Scalar > ,
371
+ filter: Tensor < Scalar > ,
372
+ strides: ( Int , Int , Int , Int ) ,
373
+ padding: Padding
374
+ ) -> ( Tensor < Scalar > , ( Tensor < Scalar > ) -> ( Tensor < Scalar > , Tensor < Scalar > ) ) {
375
+ let value = depthwiseConv2D ( x, filter: filter, strides: strides,
376
+ padding: padding)
377
+ return ( value, { v in
378
+ return (
379
+ depthwiseConv2dBackpropInput ( v, shape: x. shapeTensor, filter: filter,
380
+ strides: strides, padding: padding
381
+ ) ,
382
+ depthwiseConv2dBackpropFilter ( v, input: x, filterSizes: filter. shapeTensor,
383
+ strides: strides, padding: padding
384
+ )
385
+ )
386
+ } )
387
+ }
388
+
294
389
@usableFromInline
295
390
func _vjpMaxPool2D< Scalar: TensorFlowFloatingPoint > (
296
391
_ x: Tensor < Scalar > ,
@@ -432,6 +527,29 @@ public func conv3D<Scalar: TensorFlowFloatingPoint>(
432
527
padding: padding. raw)
433
528
}
434
529
530
+ /// Computes a 2-D depthwise convolution with the specified input, filter, strides, and padding.
531
+ ///
532
+ /// - Parameters:
533
+ /// - input: The input.
534
+ /// - filter: The depthwise convolution filter.
535
+ /// - strides: The strides of the sliding filter for each dimension of the input.
536
+ /// - padding: The padding for the operation.
537
+ /// - Precondition: `input` must have rank 4.
538
+ /// - Precondition: `filter` must have rank 4.
539
+ @differentiable ( wrt: ( input, filter) , vjp: _vjpDepthwiseConv2D)
540
+ public func depthwiseConv2D< Scalar: TensorFlowFloatingPoint > (
541
+ _ input: Tensor < Scalar > ,
542
+ filter: Tensor < Scalar > ,
543
+ strides: ( Int , Int , Int , Int ) ,
544
+ padding: Padding
545
+ ) -> Tensor < Scalar > {
546
+ return Raw . depthwiseConv2dNative (
547
+ input,
548
+ filter: filter,
549
+ strides: [ Int32 ( strides. 0 ) , Int32 ( strides. 1 ) , Int32 ( strides. 2 ) , Int32 ( strides. 3 ) ] ,
550
+ padding: padding. raw)
551
+ }
552
+
435
553
/// Computes a 2-D max pooling, with the specified filter sizes, strides, and
436
554
/// padding.
437
555
///
0 commit comments