Skip to content

Commit 6319332

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (SlidingWindow)"
## The Operator `nn.Module` invocations of [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) and [`nn.ConvTranspose2d`](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html#torch.nn.ConvTranspose2d) get compiled to `aten.convolution.default` in the Edge Dialect, which carries the signature ``` - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, SymInt[] padding, int[] dilation, bool transposed, SymInt[] output_padding, int groups) -> Tensor ``` ## Summary (cases handled) We introduce support for the convolution cases covered by [ATen-VK's default SlidingWindow implementation](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L73). This is achieved by - reusing the [existing `conv2d.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d.glsl), and - [moving special weights prepacking from CPU](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/ops/Convolution.cpp#L134-L235) to the GPU in `conv2d_prepack_weights.glsl`. We also include resizing support for dynamic shapes. Note that only height and width of the input can vary. ## Cases not handled The implementation is on-par with ATen-VK's SlidingWindow. This means the following cases are missing: 1. **Groups G > 1.** Largely not covered by ATen-VK. `G = in_channels` is covered by ATen-VK's Depthwise impl and will be added soon. 2. **Batch (input) N > 1.** Not covered by ATen-VK. 3. **Padding > 0 while Dilation, Kernel > 1.** Not covered by ATen-VK. ## Coming soon 1. Transpose convolution 2. Depthwise convolution (for completeness) 3. Pointwise convolution (for optimization) 4. Null bias Differential Revision: [D55346778](https://our.internmc.facebook.com/intern/diff/D55346778/) [ghstack-poisoned]
2 parents 1cd8c02 + 6c50546 commit 6319332

File tree

14 files changed

+301
-199
lines changed

14 files changed

+301
-199
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,27 @@ ValueRef ComputeGraph::add_tensor(
132132
sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
133133
}
134134

135+
ValueRef ComputeGraph::add_tensor_like(
136+
const ValueRef vref,
137+
const api::StorageType storage_type,
138+
const api::GPUMemoryLayout memory_layout) {
139+
TensorRef& tref = get_val(vref).toTensorRef();
140+
return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
141+
}
142+
143+
ValueRef ComputeGraph::add_tensor_like(
144+
const ValueRef vref,
145+
const api::GPUMemoryLayout memory_layout) {
146+
TensorRef& tref = get_val(vref).toTensorRef();
147+
return add_tensor(tref.sizes, tref.dtype, memory_layout);
148+
}
149+
135150
ValueRef ComputeGraph::add_tensor(
136151
const std::vector<int64_t>& sizes,
137152
const api::ScalarType dtype,
138153
const int64_t shared_object_idx) {
139154
return add_tensor(
140-
sizes,
141-
dtype,
142-
suggested_storage_type(),
143-
suggested_memory_layout(sizes),
144-
shared_object_idx);
155+
sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
145156
}
146157

147158
ValueRef ComputeGraph::add_tensorref(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,25 @@ class ComputeGraph final {
191191
*/
192192
ValueRef add_tensor(
193193
const std::vector<int64_t>& sizes,
194-
const api::ScalarType dtype = api::ScalarType::Float,
194+
const api::ScalarType dtype,
195195
const int64_t shared_object_idx = -1);
196196

197+
/*
198+
* Add a `vTensor` value to the graph with the properties of `vref`.
199+
*/
200+
ValueRef add_tensor_like(
201+
const ValueRef vref,
202+
const api::StorageType storage_type,
203+
const api::GPUMemoryLayout memory_layout);
204+
205+
/*
206+
* Add a `vTensor` value to the graph with the properties of `vref`. The
207+
* suggested storage type will be used to construct the `vTensor`.
208+
*/
209+
ValueRef add_tensor_like(
210+
const ValueRef vref,
211+
const api::GPUMemoryLayout memory_layout);
212+
197213
/*
198214
* Add a `TensorRef` value to the graph with the specific properties. A
199215
* `TensorRef` is a reference to a `vTensor` whose data is stored in an

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
3636
api::Context* const context = graph->context();
3737
api::PipelineBarrier pipeline_barrier{};
3838

39-
TensorRef tref = graph->get_val(tref_).toTensorRef();
40-
vTensor packed = graph->get_val(packed_).toTensor();
39+
TensorRef& tref = graph->get_val(tref_).toTensorRef();
40+
vTensor& packed = graph->get_val(packed_).toTensor();
4141

4242
size_t numel = api::utils::multiply_integers(tref.sizes);
4343
api::StorageBuffer staging(graph->context(), tref.dtype, numel);

backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
3333
original_sizes;
3434

3535
// Corresponds to {3,3,8,12} in the example below.
36-
layout(set = 0, binding = 4) uniform PRECISION restrict AlignedSizes {
37-
ivec4 data;
36+
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
37+
ivec2 data;
3838
}
3939
padded_sizes;
4040

@@ -94,39 +94,31 @@ void main() {
9494
base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);
9595

9696
// Re-map the normal CPU buffer indices to special indices, through a series
97-
// of permutations: reshape is a no-op to the underlying indices, and permute
98-
// is one of the hardest math problems I've ever solved.
99-
//
97+
// of mappings: reshape is a no-op to the underlying indices, pad is hard, and
98+
// permute is one of the hardest math problems I've ever solved.
99+
const int Np = padded_sizes.data.y;
100+
const int Cp = padded_sizes.data.x;
101+
const int N = original_sizes.data.w;
102+
const int C = original_sizes.data.z;
103+
const int H = original_sizes.data.y;
104+
const int W = original_sizes.data.x;
105+
100106
// Undo step 6 premute: (4,3,3,24) -> (3,4,3,24)
101107
// Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
102108
// Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w)
103109
// Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w)
104-
const ivec4 p1 = SWAP_DIMS(
105-
p0,
106-
4,
107-
(padded_sizes.data.w / 4),
108-
(padded_sizes.data.y * padded_sizes.data.z * padded_sizes.data.x));
109-
const ivec4 p2 = SWAP_DIMS(
110-
p1,
111-
padded_sizes.data.y,
112-
(padded_sizes.data.z / 4),
113-
(padded_sizes.data.x * 4));
114-
const ivec4 p3 = SWAP_DIMS(p2, padded_sizes.data.x, 4, 1);
115-
const ivec4 p4 = SWAP_DIMS(p3, padded_sizes.data.y, 4, padded_sizes.data.x);
110+
const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Np / 4), (H * Cp * W));
111+
const ivec4 p2 = SWAP_ADJ_DIMS(p1, H, (Cp / 4), (W * 4));
112+
const ivec4 p3 = SWAP_ADJ_DIMS(p2, W, 4, 1);
113+
const ivec4 p4 = SWAP_ADJ_DIMS(p3, H, 4, W);
116114

117-
// For values in the padded region, write zero instead of buffer data.
118-
//
119115
// Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
120-
const ivec4 c = p4 %
121-
(padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x) /
122-
(padded_sizes.data.y * padded_sizes.data.x);
123-
const ivec4 n =
124-
p4 / (padded_sizes.data.z * padded_sizes.data.y * padded_sizes.data.x);
125-
const ivec4 p5 = p4 -
126-
n * (padded_sizes.data.z - original_sizes.data.z) * padded_sizes.data.y *
127-
padded_sizes.data.x;
128-
const ivec4 mask = ivec4(greaterThanEqual(c, original_sizes.data.zzzz)) |
129-
ivec4(greaterThanEqual(n, original_sizes.data.wwww));
116+
// For values in the padded region, write zero instead of buffer data.
117+
const ivec4 c = p4 % (Cp * H * W) / (H * W);
118+
const ivec4 n = p4 / (Cp * H * W);
119+
const ivec4 p5 = p4 - n * (Cp - C) * H * W;
120+
const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
121+
ivec4(greaterThanEqual(n, ivec4(N)));
130122

131123
${T[DTYPE]} val_x = mix(buffer_in.data[p5.x], 0, mask.x);
132124
${T[DTYPE]} val_y = mix(buffer_in.data[p5.y], 0, mask.y);

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@
4646
#define STRIDE_HEIGHT_PACKED(vec) (vec.x)
4747

4848
// Given a buffer(1-D) index cur, compute a new index where the corresponding
49-
// tensor(N-D)'s x and y dimensions are swapped, and size is of the M-D plane of
50-
// dimensions lower than x and y.
51-
#define SWAP_DIMS(cur, x, y, size) \
52-
cur + \
53-
size*( \
54-
(1 - y) * ((cur % (x * y * size)) / (y * size)) + \
55-
(x - 1) * ((cur % (y * size)) / size))
49+
// tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane
50+
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
51+
// tensor of shape {4,3,2,24} to obtain {3,4,2,24}. Then, x=4, y=3 and
52+
// plane=2*24=48.
53+
#define SWAP_ADJ_DIMS(cur, x, y, plane) \
54+
cur + \
55+
plane*( \
56+
(1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \
57+
(x - 1) * ((cur % (y * plane)) / plane))

0 commit comments

Comments
 (0)