Commit 6118012

Update on "[ET-VK][Ez] Introduce convenience constexpr for StorageTypes and GPUMemoryLayouts"
## Context

Introduce the following convenience `constexpr`:

* `api::kBuffer`, `api::kTexture3D`, and `api::kTexture2D`
* `api::kWidthPacked`, `api::kHeightPacked`, and `api::kChannelsPacked`

Also remove the `api::StorageType::UNKNOWN` enum entry as it doesn't really serve any purpose.

Differential Revision: [D55811278](https://our.internmc.facebook.com/intern/diff/D55811278/)

[ghstack-poisoned]
2 parents 8fa4047 + c16f220 commit 6118012
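
As a quick illustration of the aliases described in the commit message, here is a minimal sketch of a call site before and after; the long-form enum spellings are assumptions for illustration, and only the `api::k*` names and the `add_tensor` signature come from this change:

```cpp
// Before: spell out the full enum entries (spellings assumed for illustration).
ValueRef a = graph.add_tensor(
    sizes,
    api::ScalarType::Float,
    api::StorageType::TEXTURE_3D,
    api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);

// After: the convenience constexpr values introduced here.
ValueRef b = graph.add_tensor(
    sizes, api::ScalarType::Float, api::kTexture3D, api::kChannelsPacked);
```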

21 files changed (+869, -72 lines)

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.max_pool2d_with_indices.default,
             # Sum
             exir_ops.edge.aten.sum.dim_IntList,
+            # Convolution operators
+            exir_ops.edge.aten.convolution.default,
             # Other
             operator.getitem,
         ]

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 16 additions & 5 deletions
@@ -132,16 +132,27 @@ ValueRef ComputeGraph::add_tensor(
       sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
 }

+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
+    const api::StorageType storage_type,
+    const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
+  return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
+}
+
+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
+    const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
+  return add_tensor(tref.sizes, tref.dtype, memory_layout);
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const api::ScalarType dtype,
     const int64_t shared_object_idx) {
   return add_tensor(
-      sizes,
-      dtype,
-      suggested_storage_type(),
-      suggested_memory_layout(sizes),
-      shared_object_idx);
+      sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
 }

 ValueRef ComputeGraph::add_tensorref(
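
For context, a hedged sketch of how an op implementation might use the new `add_tensor_like` overloads; the surrounding names (`weight_ref`, the particular layout choices) are hypothetical, and only the overload signatures come from this diff:

```cpp
// weight_ref refers to a TensorRef already added to the graph.
// Allocate a vTensor with the same sizes and dtype as the reference,
// explicitly requesting a 2D texture with width-packed layout.
ValueRef packed_weight =
    graph.add_tensor_like(weight_ref, api::kTexture2D, api::kWidthPacked);

// Or keep the graph's suggested storage type and only pin the memory layout.
ValueRef packed_default =
    graph.add_tensor_like(weight_ref, api::kChannelsPacked);
```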

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 18 additions & 2 deletions
@@ -172,7 +172,7 @@ class ComputeGraph final {
       const api::ScalarType dtype,
       const api::StorageType storage_type,
       const api::GPUMemoryLayout memory_layout,
-      const int64_t shared_object_idx);
+      const int64_t shared_object_idx = -1);

   /*
    * Add a `vTensor` value to the graph with the specified properties. The
@@ -191,9 +191,25 @@
    */
   ValueRef add_tensor(
       const std::vector<int64_t>& sizes,
-      const api::ScalarType dtype = api::ScalarType::Float,
+      const api::ScalarType dtype,
       const int64_t shared_object_idx = -1);

+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::StorageType storage_type,
+      const api::GPUMemoryLayout memory_layout);
+
+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`. The
+   * suggested storage type will be used to construct the `vTensor`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::GPUMemoryLayout memory_layout);
+
   /*
    * Add a `TensorRef` value to the graph with the specific properties. A
    * `TensorRef` is a reference to a `vTensor` whose data is stored in an
backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ PrepackNode::PrepackNode(
 void PrepackNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();

-  TensorRef tref = graph->get_val(tref_).toTensorRef();
-  vTensor packed = graph->get_val(packed_).toTensor();
+  TensorRef& tref = graph->get_val(tref_).toTensorRef();
+  vTensor& packed = graph->get_val(packed_).toTensor();

   size_t numel = api::utils::multiply_integers(tref.sizes);
   api::StorageBuffer staging(graph->context(), tref.dtype, numel);
Lines changed: 130 additions & 0 deletions (new file: conv2d compute shader)
@@ -0,0 +1,130 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
  uvec4 data;
}
out_extents;

layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
  uvec4 data;
}
in_extents;

layout(set = 0, binding = 6) uniform PRECISION restrict Params {
  ivec2 kernel_size;
  ivec2 stride;
  ivec2 padding;
  ivec2 dilation;
}
params;

// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
  ivec2 overlay_region;
  int in_group_size;
}
extra_params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes a 2D convolution. Each shader invocation calculates the output at
 * a single output location.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * params.stride - params.padding;

  // Compute the start and end of the input indices to load. Padding is assumed
  // to be constant 0 padding, so reads from the padding region are skipped.
  const ivec2 start = max(ivec2(0), ipos);
  const ivec2 end = min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy));
  // Compute the start of the kernel based on how far we are skipping ahead when
  // reading the input. Note that these are "canonical" indices.
  ivec2 kstart = (start - ipos) / params.dilation;
  // During prepacking, the weight tensor was rearranged in order to optimize
  // for data access linearity in this shader. Therefore we need to adjust the
  // canonical coordinates to the corresponding index in the rearranged weight
  // tensor. The x-coordinate is multiplied by 4 since each group of 4 channels
  // is folded into the X axis. The y-coordinate is offset based on the z-
  // coordinate because the 2D planes were stacked atop each other vertically.
  kstart.x *= 4;
  kstart.y += pos.z * params.kernel_size.y;

  // Perform the convolution by iterating over the overlay region.
  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
  const int ic4 = extra_params.in_group_size / 4;
  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
    for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
      for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
        const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);
        const ivec4 kxs = kx + ivec4(0, 1, 2, 3);

        // To explain the calculation below, the contents of in_texel and the
        // group of 4 texels loaded from kernel_in are shown:
        //
        //   in_texel               kernel_in
        //    -x->                   ---x--->
        //   +---+           +----+----+----+----+
        // ^ | w |        ^  | D0 | D1 | D2 | D3 |
        // | +---+        |  +----+----+----+----+
        // | | z |        |  | C0 | C1 | C2 | C3 |
        // z +---+        z  +----+----+----+----+
        // | | y |        |  | B0 | B1 | B2 | B3 |
        // | +---+        |  +----+----+----+----+
        //   | x |           | A0 | A1 | A2 | A3 |
        //   +---+           +----+----+----+----+
        //
        // In the kernel_in graphic, cells sharing the same letter are from
        // the same batch/output channel index, and the number denotes a unique
        // channel index. To calculate the output texel, the following
        // calculation is performed:
        //
        //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
        //  | x | | D0 |     | y | | D1 |     | z | | D2 |     | w | | D3 |
        //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
        //  | x | | C0 |     | y | | C1 |     | z | | C2 |     | w | | C3 |
        //  +---+X+----+  +  +---+X+----+  +  +---+X+----+  +  +---+X+----+
        //  | x | | B0 |     | y | | B1 |     | z | | B2 |     | w | | B3 |
        //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
        //  | x | | A0 |     | y | | A1 |     | z | | A2 |     | w | | A3 |
        //  +---+ +----+     +---+ +----+     +---+ +----+     +---+ +----+
        //
        // which is expressed in the following statements.

        sum = fma(in_texel.xxxx, texelFetch(kernel_in, ivec2(kxs.x, ky), 0), sum);
        sum = fma(in_texel.yyyy, texelFetch(kernel_in, ivec2(kxs.y, ky), 0), sum);
        sum = fma(in_texel.zzzz, texelFetch(kernel_in, ivec2(kxs.z, ky), 0), sum);
        sum = fma(in_texel.wwww, texelFetch(kernel_in, ivec2(kxs.w, ky), 0), sum);
      }
    }
  }

  imageStore(image_out, pos, sum);
}
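
To make the index arithmetic in the shader easier to follow, below is a rough scalar C++ rendering of the same per-output-element computation, ignoring the channels-packed texel layout and the rearranged kernel; it is purely illustrative and not code from this diff:

```cpp
#include <vector>

// One output element of a 2D convolution with constant zero padding.
// in: [C][H][W] input, w: [C][Kh][Kw] weights for a single output channel.
float conv2d_at(
    const std::vector<std::vector<std::vector<float>>>& in,
    const std::vector<std::vector<std::vector<float>>>& w,
    float bias,
    int out_y, int out_x,
    int stride_y, int stride_x,
    int pad_y, int pad_x,
    int dil_y, int dil_x) {
  const int C = in.size();
  const int H = in[0].size();
  const int W = in[0][0].size();
  const int Kh = w[0].size();
  const int Kw = w[0][0].size();

  // Top-left corner of the overlay region; may be negative inside the padding.
  const int ipos_y = out_y * stride_y - pad_y;
  const int ipos_x = out_x * stride_x - pad_x;

  float sum = bias;
  for (int c = 0; c < C; ++c) {
    for (int ky = 0; ky < Kh; ++ky) {
      const int y = ipos_y + ky * dil_y;
      if (y < 0 || y >= H) continue; // reads from the zero padding contribute nothing
      for (int kx = 0; kx < Kw; ++kx) {
        const int x = ipos_x + kx * dil_x;
        if (x < 0 || x >= W) continue;
        sum += in[c][y][x] * w[c][ky][kx];
      }
    }
  }
  return sum;
}
```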
Lines changed: 18 additions & 0 deletions (new file: conv2d shader codegen config)
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d:
  parameter_names_with_default_values:
    NDIM: 3
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d
Lines changed: 131 additions & 0 deletions (new file: conv2d_prepack_weights compute shader)
@@ -0,0 +1,131 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
  ${T[DTYPE]} data[];
}
buffer_in;

// Corresponds to {1,4,9,24} in the example below.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
  ivec4 data;
}
gpu_sizes;

// Corresponds to {3,3,7,10} in the example below.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
  ivec4 data;
}
original_sizes;

// Corresponds to {8,12} in the example below.
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
  ivec2 data;
}
padded_sizes;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes special prepacking for a 2D convolution. Each shader invocation
 * calculates the input buffer location to read into the desired texel. This
 * packing was originally developed on CPU and that approach is described in
 * the rest of this comment. Refer to the code-level comments for how we
 * translate it to GPU by reversing the steps.
 *
 * Consider an example weight tensor of size {10,7,3,3}. The following
 * transformations will be applied.
 *
 * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
 *    batches and 1 channel of padding are added, producing a tensor of size
 *    {12,8,3,3}.
 *    at::pad(x, {0,0,0,0,0,1,0,2}, "constant", 0);
 *
 * 2. Split the tensor along the C dim so that each split has 4 channels.
 *    x.reshape({12,2,4,3,3});
 *
 * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows
 *    at H=0 of the split have values
 *    0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
 *
 *    where | denotes a channel boundary. Then, the goal is to combine those rows
 *    into one row with the values
 *    0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
 *
 *    x.permute({0,1,3,4,2}).reshape({12,2,3,12});
 *
 * 4. Stack the splits belonging to the same batch horizontally by swapping the
 *    C and H dims.
 *    x.permute({0,2,1,3}).reshape({12,3,24});
 *
 * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along
 *    the N dim so that each split has 4 batches.
 *    x.reshape({3,4,3,24});
 *
 * 6. Stack the batches on each other vertically by swapping the N and C dims.
 *    x.permute({1,0,2,3}).reshape({4,9,24});
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data);

  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
    return;
  }

  // As in usual staging shaders, map from GPU texel position to normal CPU
  // buffer indices: (24,9) -> (4,9,24)
  const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data);
  const ivec4 p0 =
      base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);

  // Re-map the normal CPU buffer indices to special indices, through a series
  // of mappings: reshape is a no-op to the underlying indices, so we only map
  // for pad and permute.
  const int Np = padded_sizes.data.y;
  const int Cp = padded_sizes.data.x;
  const int N = original_sizes.data.w;
  const int C = original_sizes.data.z;
  const int H = original_sizes.data.y;
  const int W = original_sizes.data.x;

  // Undo step 6 permute: (4,3,3,24) -> (3,4,3,24)
  // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
  // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w)
  // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w)
  const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Np / 4), (H * Cp * W));
  const ivec4 p2 = SWAP_ADJ_DIMS(p1, H, (Cp / 4), (W * 4));
  const ivec4 p3 = SWAP_ADJ_DIMS(p2, W, 4, 1);
  const ivec4 p4 = SWAP_ADJ_DIMS(p3, H, 4, W);

  // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
  // For values in the padded region, write zero instead of buffer data.
  const ivec4 c = p4 % (Cp * H * W) / (H * W);
  const ivec4 n = p4 / (Cp * H * W);
  const ivec4 p5 = p4 - n * (Cp - C) * H * W;
  const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
      ivec4(greaterThanEqual(n, ivec4(N)));

  ${T[DTYPE]} val_x = mix(buffer_in.data[p5.x], 0, mask.x);
  ${T[DTYPE]} val_y = mix(buffer_in.data[p5.y], 0, mask.y);
  ${T[DTYPE]} val_z = mix(buffer_in.data[p5.z], 0, mask.z);
  ${T[DTYPE]} val_w = mix(buffer_in.data[p5.w], 0, mask.w);

  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);

  imageStore(image_out, pos.xy, texel);
}
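
The six CPU-side steps in the comment above can be written out directly; here is a hedged ATen-style sketch for the {10,7,3,3} example, assuming the standard at::pad/permute/reshape calls behave as the comment's pseudocode suggests (this is illustrative, not code from the diff):

```cpp
#include <ATen/ATen.h>

// CPU reference for the weight prepacking described in the shader comment,
// specialized to an example weight of size {10, 7, 3, 3} (N, C, H, W).
at::Tensor prepack_conv2d_weights_example(const at::Tensor& x_in) {
  // 1. Pad N and C up to multiples of 4: {10,7,3,3} -> {12,8,3,3}.
  at::Tensor x = at::pad(x_in, {0, 0, 0, 0, 0, 1, 0, 2}, "constant", 0);

  // 2. Split the channel dim into groups of 4: {12,8,3,3} -> {12,2,4,3,3}.
  x = x.reshape({12, 2, 4, 3, 3});

  // 3. Fold each group of 4 channels into the width dim:
  //    {12,2,4,3,3} -> {12,2,3,3,4} -> {12,2,3,12}.
  x = x.permute({0, 1, 3, 4, 2}).reshape({12, 2, 3, 12});

  // 4. Stack the channel groups of one batch horizontally: {12,2,3,12} -> {12,3,24}.
  x = x.permute({0, 2, 1, 3}).reshape({12, 3, 24});

  // 5. Split the batch dim into groups of 4: {12,3,24} -> {3,4,3,24}.
  x = x.reshape({3, 4, 3, 24});

  // 6. Stack the batch groups vertically: {3,4,3,24} -> {4,9,24}.
  return x.permute({1, 0, 2, 3}).reshape({4, 9, 24}).contiguous();
}
```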
Lines changed: 17 additions & 0 deletions (new file: conv2d_prepack_weights shader codegen config)
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

conv2d_prepack_weights:
  parameter_names_with_default_values:
    DTYPE: float
  generate_variant_forall:
    DTYPE:
      - VALUE: half
        SUFFIX: half
      - VALUE: float
        SUFFIX: float
  shader_variants:
    - NAME: conv2d_prepack_weights
