Skip to content

Commit 9d3b714

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator `nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature ``` - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor ``` ## Summary (cases handled) We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by - reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and - [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`. We also include resizing support for dynamic shapes. Note that only height and width of the input can vary. ## Cases not handled The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing: 1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon. 2. **Batch (input) N > 1.** Not covered by ATen-VK. 3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK. ## Coming soon 1. Transpose convolution 2. Depthwise convolution (for completeness) 3. Pointwise convolution (for optimization) 4. Null bias Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/) [ghstack-poisoned]
2 parents 6319332 + 0dc0578 commit 9d3b714

File tree

6 files changed

+35
-37
lines changed

6 files changed

+35
-37
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ void main() {
8484
for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
8585
for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
8686
const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
87+
const ivec4 kxs = kx + ivec4(0, 1, 2, 3);
8788

8889
// To explain the calculation below, the contents of in_texel and the
8990
// group of 4 texels loaded from kernel_in are shown:
@@ -117,17 +118,10 @@ void main() {
117118
//
118119
// which is expressed in the following statements.
119120

120-
const ${VEC4_T[DTYPE]} ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0);
121-
sum = fma(in_texel.xxxx, ktex_0, sum);
122-
123-
const ${VEC4_T[DTYPE]} ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0);
124-
sum = fma(in_texel.yyyy, ktex_1, sum);
125-
126-
const ${VEC4_T[DTYPE]} ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0);
127-
sum = fma(in_texel.zzzz, ktex_2, sum);
128-
129-
const ${VEC4_T[DTYPE]} ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0);
130-
sum = fma(in_texel.wwww, ktex_3, sum);
121+
sum = fma(in_texel.xxxx, texelFetch(kernel_in, ivec2(kxs.x, ky), 0), sum);
122+
sum = fma(in_texel.yyyy, texelFetch(kernel_in, ivec2(kxs.y, ky), 0), sum);
123+
sum = fma(in_texel.zzzz, texelFetch(kernel_in, ivec2(kxs.z, ky), 0), sum);
124+
sum = fma(in_texel.wwww, texelFetch(kernel_in, ivec2(kxs.w, ky), 0), sum);
131125
}
132126
}
133127
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
3232
}
3333
original_sizes;
3434

35-
// Corresponds to {3,3,8,12} in the example below.
35+
// Corresponds to {8,12} in the example below.
3636
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
3737
ivec2 data;
3838
}
@@ -53,7 +53,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5353
* 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
5454
* batches and 1 channel of padding are added, producing a tensor of size
5555
* {12,8,3,3}.
56-
* at::pad(x, {0,0,0,0,0,2,0,1}, "constant", 0);
56+
* at::pad(x, {0,0,0,0,0,1,0,2}, "constant", 0);
5757
*
5858
* 2. Split the tensor along the C dim so that each split has 4 channels.
5959
* x.reshape({12,2,4,3,3});
@@ -94,8 +94,8 @@ void main() {
9494
base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);
9595

9696
// Re-map the normal CPU buffer indices to special indices, through a series
97-
// of mappings: reshape is a no-op to the underlying indices, pad is hard, and
98-
// permute is one of the hardest math problems I've ever solved.
97+
// of mappings: reshape is a no-op to the underlying indices, so we only map
98+
// for pad and permute.
9999
const int Np = padded_sizes.data.y;
100100
const int Cp = padded_sizes.data.x;
101101
const int N = original_sizes.data.w;

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void resize_conv2d_node(
4141
*graph,
4242
self.sizes(),
4343
extra_args[0],
44-
/*kernel_only = */ false,
44+
/*kernel_size_only = */ false,
4545
extra_args[1],
4646
extra_args[2],
4747
extra_args[3]);
@@ -56,13 +56,11 @@ ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) {
5656
VK_THROW("aten.convolution.default: Null bias is not supported yet!");
5757
}
5858

59-
TensorRef& tref = graph.get_val(vref).toTensorRef();
60-
ValueRef v = graph.add_tensor(
61-
tref.sizes,
62-
tref.dtype,
59+
ValueRef v = graph.add_tensor_like(
60+
vref,
6361
api::StorageType::TEXTURE_2D,
6462
api::GPUMemoryLayout::TENSOR_WIDTH_PACKED);
65-
vTensor t = graph.get_val(v).toTensor();
63+
vTensor& t = graph.get_val(v).toTensor();
6664

6765
api::ShaderInfo shader = get_nchw_to_image_shader(t);
6866

@@ -110,7 +108,7 @@ ValueRef prepack_weights(ComputeGraph& graph, const ValueRef vref) {
110108
graph.get_val(vref).toTensorRef().dtype,
111109
api::StorageType::TEXTURE_2D,
112110
api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
113-
vTensor t = graph.get_val(v).toTensor();
111+
vTensor& t = graph.get_val(v).toTensor();
114112

115113
api::utils::uvec3 global_size = t.extents();
116114
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
@@ -163,7 +161,7 @@ Conv2dParams create_conv2d_params(
163161
});
164162
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
165163
const int32_t in_group_size = api::utils::safe_downcast<int32_t>(
166-
api::utils::align_up(weight_sizes.at(0), INT64_C(4)));
164+
api::utils::align_up(weight_sizes.at(1), INT64_C(4)));
167165
return {overlay_region, in_group_size};
168166
}
169167

@@ -187,21 +185,21 @@ void add_conv2d_node(
187185
const ValueRef dilation,
188186
const ValueRef out) {
189187
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
188+
ValueRef arg_weight = prepack_weights(graph, weight);
189+
ValueRef arg_bias = prepack_biases(graph, bias);
190+
190191
vTensor& t_in = graph.get_val(arg_in).toTensor();
191192
vTensor& t_out = graph.get_val(out).toTensor();
192193

193194
check_conv2d_args(t_in, t_out);
194195

195-
ValueRef arg_weight = prepack_weights(graph, weight);
196-
ValueRef arg_bias = prepack_biases(graph, bias);
197-
198196
api::utils::uvec3 global_size = t_out.virtual_extents();
199197
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
200198

201199
KernelParams kernel_params = create_kernel_params(
202200
graph,
203201
weight,
204-
/*kernel_only = */ false,
202+
/*kernel_size_only = */ false,
205203
stride,
206204
padding,
207205
dilation);

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void resize_max_pool2d_node(
3939
*graph,
4040
self.sizes(),
4141
extra_args[0],
42-
/*kernel_only = */ true,
42+
/*kernel_size_only = */ true,
4343
extra_args[1],
4444
extra_args[2],
4545
extra_args[3],
@@ -83,7 +83,12 @@ void add_max_pool2d_node(
8383
apply_dtype_suffix(kernel_name, t_out);
8484

8585
KernelParams kernel_params = create_kernel_params(
86-
graph, kernel_size, /*kernel_only = */ true, stride, padding, dilation);
86+
graph,
87+
kernel_size,
88+
/*kernel_size_only = */ true,
89+
stride,
90+
padding,
91+
dilation);
8792

8893
graph.execute_nodes().emplace_back(new ExecuteNode(
8994
graph,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) {
1818
api::utils::ivec2 make_ivec2_kernel_size(
1919
ComputeGraph& graph,
2020
const ValueRef weight,
21-
const bool kernel_only) {
22-
if (kernel_only) {
21+
const bool kernel_size_only) {
22+
if (kernel_size_only) {
2323
return make_ivec2_from_list(graph, weight);
2424
} else {
2525
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
@@ -30,12 +30,12 @@ api::utils::ivec2 make_ivec2_kernel_size(
3030
KernelParams create_kernel_params(
3131
ComputeGraph& graph,
3232
const ValueRef weight,
33-
const bool kernel_only,
33+
const bool kernel_size_only,
3434
const ValueRef stride,
3535
const ValueRef padding,
3636
const ValueRef dilation) {
3737
return {
38-
make_ivec2_kernel_size(graph, weight, kernel_only),
38+
make_ivec2_kernel_size(graph, weight, kernel_size_only),
3939
make_ivec2_from_list(graph, stride),
4040
make_ivec2_from_list(graph, padding),
4141
make_ivec2_from_list(graph, dilation),
@@ -63,15 +63,16 @@ std::vector<int64_t> calc_out_sizes_hw(
6363
ComputeGraph& graph,
6464
const std::vector<int64_t>& in_sizes,
6565
const ValueRef weight,
66-
const bool kernel_only,
66+
const bool kernel_size_only,
6767
const ValueRef stride,
6868
const ValueRef padding,
6969
const ValueRef dilation,
7070
const ValueRef ceil_mode) {
7171
const int64_t ndim = in_sizes.size();
7272
std::vector<int64_t> out_sizes(2);
7373

74-
const auto kernel_vec = make_ivec2_kernel_size(graph, weight, kernel_only);
74+
const auto kernel_vec =
75+
make_ivec2_kernel_size(graph, weight, kernel_size_only);
7576
const auto stride_vec = make_ivec2_from_list(graph, stride);
7677
const auto padding_vec = make_ivec2_from_list(graph, padding);
7778
const auto dilation_vec = make_ivec2_from_list(graph, dilation);

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct KernelParams final {
2626
KernelParams create_kernel_params(
2727
ComputeGraph& graph,
2828
const ValueRef weight,
29-
const bool kernel_only,
29+
const bool kernel_size_only,
3030
const ValueRef stride,
3131
const ValueRef padding,
3232
const ValueRef dilation);
@@ -35,7 +35,7 @@ std::vector<int64_t> calc_out_sizes_hw(
3535
ComputeGraph& graph,
3636
const std::vector<int64_t>& in_sizes,
3737
const ValueRef weight,
38-
const bool kernel_only,
38+
const bool kernel_size_only,
3939
const ValueRef stride,
4040
const ValueRef padding,
4141
const ValueRef dilation,

0 commit comments

Comments
 (0)