-
Notifications
You must be signed in to change notification settings - Fork 607
Commit 0dc0578
committed
Update base for Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator
`nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature
```
- func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor
```
## Summary (cases handled)
We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by
- reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and
- [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`.
We also include resizing support for dynamic shapes. Note that only height and width of the input can vary.
## Cases not handled
The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing:
1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon.
2. **Batch (input) N > 1.** Not covered by ATen-VK.
3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK.
## Coming soon
1. Transpose convolution
2. Depthwise convolution (for completeness)
3. Pointwise convolution (for optimization)
4. Null bias
Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/)
[ghstack-poisoned]1 parent 6c50546 commit 0dc0578Copy full SHA for 0dc0578
File tree
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changedFilter options
Expand file treeCollapse file tree
0 file changed
+0
-0
lines changed
0 commit comments