Skip to content

Commit fd55298

Browse files
committed
[ET-VK] Integrate axis mapping into staging <-> image transfer shaders
Pull Request resolved: #5093

## Context

Building on the previous diff, this diff integrates axis mapping into staging <-> image transfer shaders. Alternative versions of the indexing utility functions are introduced to account for axis mapping. The impact on shader latency of using axis mapping in transfer shaders is examined in the next diff.

ghstack-source-id: 241282078
Differential Revision: [D62210117](https://our.internmc.facebook.com/intern/diff/D62210117/)
1 parent d5a6c3a commit fd55298

File tree

9 files changed

+136
-31
lines changed

9 files changed

+136
-31
lines changed

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_buffer(0, "w", "nchw_out", DTYPE)}
25-
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_buffer(B, "w", "nchw_out", DTYPE)}
25+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -51,7 +52,7 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
5152

5253
void main() {
5354
const ivec3 pos = ivec3(gl_GlobalInvocationID);
54-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
55+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5556

5657
if (any(greaterThanEqual(tensor_idx, sizes))) {
5758
return;

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,42 @@ ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) {
183183
return tensor_idx;
184184
}
185185

186+
/*
187+
* Derive (w,h,c,n) tensor indices from (x,y,z) texture position using axis
188+
* mapping.
189+
*/
190+
ivec4 to_tensor_idx(
191+
ivec3 pos,
192+
ivec4 sizes,
193+
const ivec4 axis_mapping,
194+
const int packed_dim) {
195+
// Align packed dim to next multiple of 4 to account for texel padding
196+
sizes[packed_dim] = alignup4(sizes[packed_dim]);
197+
198+
// Packed dim contains 4 elements per texel, so moving 1 unit traverses 4
199+
// elements in the tensor.
200+
pos[axis_mapping[packed_dim]] *= 4;
201+
202+
ivec4 tensor_idx;
203+
for (int dim = 0; dim < 3; ++dim) {
204+
tensor_idx[dim] = pos[axis_mapping[dim]];
205+
}
206+
207+
// Early return if batch is 1. Batch index will be 0.
208+
if (sizes.w == 1) {
209+
tensor_idx.w = 0;
210+
return tensor_idx;
211+
}
212+
213+
// Else, adjust the dim that's concatenated with batch. Note that the axis
214+
// mapping for the batch dim indicates WHCN dim index of the dim that it is
215+
// concatenated with, not a texture axis.
216+
tensor_idx.w = tensor_idx[axis_mapping[3]] / sizes[axis_mapping[3]];
217+
tensor_idx[axis_mapping[3]] %= sizes[axis_mapping[3]];
218+
219+
return tensor_idx;
220+
}
221+
186222
/*
187223
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim
188224
* is packed along a texel
@@ -199,6 +235,34 @@ ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
199235
return pos;
200236
}
201237

238+
/*
239+
* Derive (x,y,z) texture position from (w,h,c,n) tensor indices using axis
240+
* mapping.
241+
*/
242+
ivec3 to_texture_pos(
243+
const ivec4 idx,
244+
ivec4 sizes,
245+
const ivec4 axis_mapping,
246+
const int packed_dim) {
247+
// Align packed dim to next multiple of 4 to account for texel padding
248+
sizes[packed_dim] = alignup4(sizes[packed_dim]);
249+
250+
ivec3 pos;
251+
for (int dim = 0; dim < 3; ++dim) {
252+
pos[axis_mapping[dim]] = idx[dim];
253+
}
254+
255+
// Adjust batch dim if needed
256+
if (sizes.w > 1) {
257+
pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes.w;
258+
}
259+
260+
// Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
261+
// tensor elements in that dim.
262+
pos[axis_mapping[packed_dim]] /= 4;
263+
return pos;
264+
}
265+
202266
/*
203267
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, which dim
204268
* is packed along a texel
@@ -218,6 +282,35 @@ ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
218282
return pos;
219283
}
220284

285+
/*
286+
* Derive (x,y,z,i) texel element position from the (w,h,c,n) tensor index using
287+
* the axis mapping.
288+
*/
289+
ivec4 to_texture_elem_pos(
290+
const ivec4 idx,
291+
ivec4 sizes,
292+
const ivec4 axis_mapping,
293+
const int packed_dim) {
294+
// Align packed dim to next multiple of 4 to account for texel padding
295+
sizes[packed_dim] = alignup4(sizes[packed_dim]);
296+
297+
ivec4 pos;
298+
for (int dim = 0; dim < 3; ++dim) {
299+
pos[axis_mapping[dim]] = idx[dim];
300+
}
301+
302+
// Adjust batch dim if needed
303+
if (sizes.w > 1) {
304+
pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes.w;
305+
}
306+
307+
// Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
308+
// tensor elements in that dim.
309+
pos[axis_mapping[packed_dim]] /= 4;
310+
pos.w = idx[packed_dim] % 4;
311+
return pos;
312+
}
313+
221314
//
222315
// Texel Access and Storage
223316
//

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_buffer(0, "w", "nchw_out", "int")}
20-
${layout_declare_tensor(1, "r", "t_in", "int8", "texture3d")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
22-
${layout_declare_ubo(3, "int", "out_numel")}
19+
${layout_declare_buffer(B, "w", "nchw_out", "int")}
20+
${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
21+
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
23+
${layout_declare_ubo(B, "int", "out_numel")}
2324

2425
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2526

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
25-
${layout_declare_buffer(1, "r", "nchw_in", DTYPE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_buffer(B, "r", "nchw_in", DTYPE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -53,7 +54,7 @@ VEC4_T read_texel(ivec4 tensor_idx) {
5354

5455
void main() {
5556
const ivec3 pos = ivec3(gl_GlobalInvocationID);
56-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
57+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5758
if (any(greaterThanEqual(tensor_idx, sizes))) {
5859
return;
5960
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")}
20-
${layout_declare_buffer(1, "r", "nchw_in", "int")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
19+
${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
20+
${layout_declare_buffer(B, "r", "nchw_in", "int")}
21+
${layout_declare_ubo(B, "ivec4", "sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2223

2324
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2425

@@ -36,7 +37,7 @@ int extend_sign(int x) {
3637

3738
ivec4 read_texel(ivec4 tensor_idx) {
3839
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
39-
tensor_idx, tensor_sizes, packed_dim);
40+
tensor_idx, sizes, packed_dim);
4041

4142
int shift = (1 << 8) - 1;
4243
ivec4 masks;
@@ -51,7 +52,7 @@ ivec4 read_texel(ivec4 tensor_idx) {
5152
ivec4 out_tex = ivec4(0);
5253

5354
[[unroll]] for (int i = 0; i < 4; ++i) {
54-
if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) {
55+
if (tensor_idx[packed_dim] + i < sizes[packed_dim]) {
5556
int in_texel = nchw_in[buf_indices[i] / 4];
5657
int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4));
5758
extracted_val = extend_sign(extracted_val);
@@ -64,9 +65,9 @@ ivec4 read_texel(ivec4 tensor_idx) {
6465

6566
void main() {
6667
const ivec3 pos = ivec3(gl_GlobalInvocationID);
67-
const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim);
68+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
6869

69-
if (any(greaterThanEqual(tensor_idx, tensor_sizes))) {
70+
if (any(greaterThanEqual(tensor_idx, sizes))) {
7071
return;
7172
}
7273

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ValueRef prepack_biases(
106106
graph.create_local_wg_size(v),
107107
vref,
108108
v,
109-
{t->sizes_ubo()},
109+
{t->sizes_ubo(), t->axis_mapping_ubo()},
110110
// Specialization constants
111111
{SV(t->packed_dim_whcn_idx())}));
112112

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void add_staging_to_tensor_node(
3131
graph.strides_ubo(out_tensor),
3232
graph.numel_ubo(out_tensor)});
3333
} else {
34-
ubos.append(graph.sizes_ubo(out_tensor));
34+
ubos.append(
35+
{graph.sizes_ubo(out_tensor), graph.axis_mapping_ubo(out_tensor)});
3536
}
3637

3738
graph.execute_nodes().emplace_back(new ExecuteNode(
@@ -69,7 +70,8 @@ void add_tensor_to_staging_node(
6970
graph.strides_ubo(in_tensor),
7071
graph.numel_ubo(in_tensor)});
7172
} else {
72-
ubos.append(graph.sizes_ubo(in_tensor));
73+
ubos.append(
74+
{graph.sizes_ubo(in_tensor), graph.axis_mapping_ubo(in_tensor)});
7375
}
7476

7577
// Normally, the image_to_nchw shader is structured so that each thread reads
@@ -113,7 +115,7 @@ ValueRef prepack(
113115
if (graph.is_buffer_storage(v)) {
114116
ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)});
115117
} else {
116-
ubos.append(graph.sizes_ubo(v));
118+
ubos.append({graph.sizes_ubo(v), graph.axis_mapping_ubo(v)});
117119
}
118120

119121
graph.prepack_nodes().emplace_back(new PrepackNode(

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ void record_nchw_to_image_op(
8585
vkapi::PipelineStage::COMPUTE,
8686
vkapi::MemoryAccessType::WRITE),
8787
src_buffer,
88-
v_dst.sizes_ubo());
88+
v_dst.sizes_ubo(),
89+
v_dst.axis_mapping_ubo());
8990
}
9091

9192
void record_image_to_nchw_op(
@@ -106,7 +107,8 @@ void record_image_to_nchw_op(
106107
0,
107108
dst_buffer,
108109
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
109-
v_src.sizes_ubo());
110+
v_src.sizes_ubo(),
111+
v_src.axis_mapping_ubo());
110112
}
111113

112114
void record_int8_image_to_nchw_noint8_op(
@@ -127,6 +129,7 @@ void record_int8_image_to_nchw_noint8_op(
127129
dst_buffer.buffer(),
128130
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
129131
v_src.sizes_ubo(),
132+
v_src.axis_mapping_ubo(),
130133
v_src.numel_ubo());
131134
}
132135

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
12331233
GraphConfig config;
12341234
ComputeGraph graph(config);
12351235

1236-
std::vector<int64_t> size_big = {8, 64, 124};
1237-
std::vector<int64_t> size_small = {8, 1, 124};
1236+
std::vector<int64_t> size_big = {1, 8, 8};
1237+
std::vector<int64_t> size_small = {1, 1, 8};
12381238

12391239
// Build graph
12401240

@@ -1415,8 +1415,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14151415
/*shared_object_idx = */ 4);
14161416

14171417
// +2: t.sizes_ubo() for each staging shader
1418+
// +2: t.axis_mapping_ubo() for each staging shader
14181419
// +2: staging buffer for each input tensor
1419-
EXPECT_TRUE(get_vma_allocation_count() == 4);
1420+
EXPECT_TRUE(get_vma_allocation_count() == 6);
14201421

14211422
ValueRef c = graph.add_tensor(
14221423
size_big,
@@ -1433,8 +1434,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14331434

14341435
// +2: alpha UBO, broadcast UBO for arithmetic shader
14351436
// +1: t.sizes_ubo() uniform buffer for staging shader
1437+
// +1: t.axis_mapping_ubo() uniform buffer for staging shader
14361438
// +1: staging buffer for the input tensor
1437-
EXPECT_TRUE(get_vma_allocation_count() == 9);
1439+
EXPECT_TRUE(get_vma_allocation_count() == 12);
14381440

14391441
ValueRef e = graph.add_tensor(
14401442
size_big,
@@ -1450,14 +1452,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14501452

14511453
// +2: alpha UBO, broadcast UBO for arithmetic shader
14521454
// +1: t.sizes_ubo() for staging shader
1455+
// +1: t.axis_mapping_ubo() for staging shader
14531456
// +1 staging buffer for the input tensor
1454-
EXPECT_TRUE(get_vma_allocation_count() == 13);
1457+
EXPECT_TRUE(get_vma_allocation_count() == 17);
14551458

14561459
graph.prepare();
14571460
graph.encode_execute();
14581461

14591462
// +3: shared memory allocations for tensors
1460-
EXPECT_TRUE(get_vma_allocation_count() == 16);
1463+
EXPECT_TRUE(get_vma_allocation_count() == 20);
14611464

14621465
// Run graph
14631466

0 commit comments

Comments
 (0)