Skip to content

Commit 1f1a2c2

Browse files
committed
Update base for Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator `nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature ``` - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor ``` ## Summary (cases handled) We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by - reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and - [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`. We also include resizing support for dynamic shapes. Note that only height and width of the input can vary. ## Cases not handled The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing: 1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon. 2. **Batch (input) N > 1.** Not covered by ATen-VK. 3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK. ## Coming soon For our CUNET model, the first two are required and the third is useful. 1. Transpose convolution 2. Depthwise convolution (for completeness) 3. Pointwise convolution (for optimization) 4. Null bias Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/) [ghstack-poisoned]
1 parent 3f27222 commit 1f1a2c2

File tree

1 file changed

+0
-0
lines changed

1 file changed

+0
-0
lines changed

0 commit comments

Comments
 (0)