Skip to content

Commit f549d97

Browse files
committed
[ET-VK] Integrate axis mapping into staging <-> image transfer shaders
Pull Request resolved: #5093 ## Context Building on the previous diff, this diff integrates axis mapping into staging <-> image transfer shaders. Alternative versions of indexing utility functions are introduced to account for axis mapping. The impact on shader latency of using axis mapping in transfer shaders is examined in the next diff. Differential Revision: [D62210117](https://our.internmc.facebook.com/intern/diff/D62210117/) ghstack-source-id: 241249802
1 parent 8b4db6a commit f549d97

File tree

9 files changed

+136
-31
lines changed

9 files changed

+136
-31
lines changed

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_buffer(0, "w", "nchw_out", DTYPE)}
25-
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_buffer(B, "w", "nchw_out", DTYPE)}
25+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -51,7 +52,7 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
5152

5253
void main() {
5354
const ivec3 pos = ivec3(gl_GlobalInvocationID);
54-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
55+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5556

5657
if (any(greaterThanEqual(tensor_idx, sizes))) {
5758
return;

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,42 @@ ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) {
183183
return tensor_idx;
184184
}
185185

186+
/*
 * Derive (w,h,c,n) tensor indices from an (x,y,z) texture position using an
 * axis mapping.
 *
 * axis_mapping[0..2] give, for each of the W/H/C tensor dims, the texture axis
 * (0=x, 1=y, 2=z) that the dim is laid out along. axis_mapping[3] gives the
 * WHCN index of the dim that the batch dim is concatenated with (it is a
 * tensor-dim index, not a texture axis).
 *
 * sizes is the (W,H,C,N) size of the tensor; packed_dim is the WHCN index of
 * the dim packed along a texel (4 elements per texel).
 */
ivec4 to_tensor_idx(
    ivec3 pos,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  // Packed dim contains 4 elements per texel, so moving 1 unit traverses 4
  // elements in the tensor.
  pos[axis_mapping[packed_dim]] *= 4;

  ivec4 tensor_idx;
  for (int dim = 0; dim < 3; ++dim) {
    tensor_idx[dim] = pos[axis_mapping[dim]];
  }

  // Early return if batch is 1. Batch index will be 0.
  if (sizes.w == 1) {
    tensor_idx.w = 0;
    return tensor_idx;
  }

  // Else, adjust the dim that's concatenated with batch. Note that the axis
  // mapping for the batch dim indicates WHCN dim index of the dim that it is
  // concatenated with, not a texture axis. Each batch occupies
  // sizes[axis_mapping[3]] (padded) elements along that dim, so dividing by
  // that extent recovers the batch index and the remainder is the in-batch
  // index along the concatenated dim.
  tensor_idx.w = tensor_idx[axis_mapping[3]] / sizes[axis_mapping[3]];
  tensor_idx[axis_mapping[3]] %= sizes[axis_mapping[3]];

  return tensor_idx;
}
221+
186222
/*
187223
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim
188224
* is packed along a texel
@@ -199,6 +235,34 @@ ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
199235
return pos;
200236
}
201237

238+
/*
 * Derive (x,y,z) texture position from (w,h,c,n) tensor indices using axis
 * mapping. Inverse of the axis-mapping overload of to_tensor_idx().
 *
 * axis_mapping[0..2] give the texture axis for each of the W/H/C dims;
 * axis_mapping[3] gives the WHCN index of the dim that batches are
 * concatenated along. sizes is the (W,H,C,N) tensor size; packed_dim is the
 * WHCN index of the texel-packed dim.
 */
ivec3 to_texture_pos(
    const ivec4 idx,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  ivec3 pos;
  for (int dim = 0; dim < 3; ++dim) {
    pos[axis_mapping[dim]] = idx[dim];
  }

  // Adjust batch dim if needed. Each batch occupies sizes[axis_mapping[3]]
  // (padded) elements along the concatenated dim, so offset by the batch
  // index times that extent — NOT times the batch count (sizes.w), which
  // would only coincide when the two happen to be equal. This mirrors
  // to_tensor_idx(), which divides by sizes[axis_mapping[3]].
  if (sizes.w > 1) {
    pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes[axis_mapping[3]];
  }

  // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
  // tensor elements in that dim.
  pos[axis_mapping[packed_dim]] /= 4;
  return pos;
}
265+
202266
/*
203267
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, which dim
204268
* is packed along a texel
@@ -218,6 +282,35 @@ ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
218282
return pos;
219283
}
220284

285+
/*
 * Derive (x,y,z,i) texel element position from the (w,h,c,n) tensor index
 * using the axis mapping: (x,y,z) is the texel's texture position and i is
 * the element's index within the texel (0..3 along the packed dim).
 *
 * axis_mapping[0..2] give the texture axis for each of the W/H/C dims;
 * axis_mapping[3] gives the WHCN index of the dim that batches are
 * concatenated along. sizes is the (W,H,C,N) tensor size; packed_dim is the
 * WHCN index of the texel-packed dim.
 */
ivec4 to_texture_elem_pos(
    const ivec4 idx,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  ivec4 pos;
  for (int dim = 0; dim < 3; ++dim) {
    pos[axis_mapping[dim]] = idx[dim];
  }

  // Adjust batch dim if needed. Each batch occupies sizes[axis_mapping[3]]
  // (padded) elements along the concatenated dim, so offset by the batch
  // index times that extent — NOT times the batch count (sizes.w). This
  // mirrors to_tensor_idx(), which divides by sizes[axis_mapping[3]].
  if (sizes.w > 1) {
    pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes[axis_mapping[3]];
  }

  // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
  // tensor elements in that dim; the remainder selects the element within
  // the texel.
  pos[axis_mapping[packed_dim]] /= 4;
  pos.w = idx[packed_dim] % 4;
  return pos;
}
313+
221314
//
222315
// Texel Access and Storage
223316
//

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_buffer(0, "w", "nchw_out", "int")}
20-
${layout_declare_tensor(1, "r", "t_in", "int8", "texture3d")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
22-
${layout_declare_ubo(3, "int", "out_numel")}
19+
${layout_declare_buffer(B, "w", "nchw_out", "int")}
20+
${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
21+
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
23+
${layout_declare_ubo(B, "int", "out_numel")}
2324

2425
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2526

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
25-
${layout_declare_buffer(1, "r", "nchw_in", DTYPE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_buffer(B, "r", "nchw_in", DTYPE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -53,7 +54,7 @@ VEC4_T read_texel(ivec4 tensor_idx) {
5354

5455
void main() {
5556
const ivec3 pos = ivec3(gl_GlobalInvocationID);
56-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
57+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5758
if (any(greaterThanEqual(tensor_idx, sizes))) {
5859
return;
5960
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")}
20-
${layout_declare_buffer(1, "r", "nchw_in", "int")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
19+
${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
20+
${layout_declare_buffer(B, "r", "nchw_in", "int")}
21+
${layout_declare_ubo(B, "ivec4", "sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2223

2324
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2425

@@ -36,7 +37,7 @@ int extend_sign(int x) {
3637

3738
ivec4 read_texel(ivec4 tensor_idx) {
3839
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
39-
tensor_idx, tensor_sizes, packed_dim);
40+
tensor_idx, sizes, packed_dim);
4041

4142
int shift = (1 << 8) - 1;
4243
ivec4 masks;
@@ -51,7 +52,7 @@ ivec4 read_texel(ivec4 tensor_idx) {
5152
ivec4 out_tex = ivec4(0);
5253

5354
[[unroll]] for (int i = 0; i < 4; ++i) {
54-
if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) {
55+
if (tensor_idx[packed_dim] + i < sizes[packed_dim]) {
5556
int in_texel = nchw_in[buf_indices[i] / 4];
5657
int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4));
5758
extracted_val = extend_sign(extracted_val);
@@ -64,9 +65,9 @@ ivec4 read_texel(ivec4 tensor_idx) {
6465

6566
void main() {
6667
const ivec3 pos = ivec3(gl_GlobalInvocationID);
67-
const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim);
68+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
6869

69-
if (any(greaterThanEqual(tensor_idx, tensor_sizes))) {
70+
if (any(greaterThanEqual(tensor_idx, sizes))) {
7071
return;
7172
}
7273

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ValueRef prepack_biases(
106106
graph.create_local_wg_size(v),
107107
vref,
108108
v,
109-
{t->sizes_ubo()},
109+
{t->sizes_ubo(), t->axis_mapping_ubo()},
110110
// Specialization constants
111111
{SV(t->packed_dim_whcn_idx())}));
112112

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void add_staging_to_tensor_node(
3131
graph.strides_ubo(out_tensor),
3232
graph.numel_ubo(out_tensor)});
3333
} else {
34-
ubos.append(graph.sizes_ubo(out_tensor));
34+
ubos.append(
35+
{graph.sizes_ubo(out_tensor), graph.axis_mapping_ubo(out_tensor)});
3536
}
3637

3738
graph.execute_nodes().emplace_back(new ExecuteNode(
@@ -69,7 +70,8 @@ void add_tensor_to_staging_node(
6970
graph.strides_ubo(in_tensor),
7071
graph.numel_ubo(in_tensor)});
7172
} else {
72-
ubos.append(graph.sizes_ubo(in_tensor));
73+
ubos.append(
74+
{graph.sizes_ubo(in_tensor), graph.axis_mapping_ubo(in_tensor)});
7375
}
7476

7577
// Normally, the image_to_nchw shader is structured so that each thread reads
@@ -113,7 +115,7 @@ ValueRef prepack(
113115
if (graph.is_buffer_storage(v)) {
114116
ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)});
115117
} else {
116-
ubos.append(graph.sizes_ubo(v));
118+
ubos.append({graph.sizes_ubo(v), graph.axis_mapping_ubo(v)});
117119
}
118120

119121
graph.prepack_nodes().emplace_back(new PrepackNode(

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ void record_nchw_to_image_op(
8585
vkapi::PipelineStage::COMPUTE,
8686
vkapi::MemoryAccessType::WRITE),
8787
src_buffer,
88-
v_dst.sizes_ubo());
88+
v_dst.sizes_ubo(),
89+
v_dst.axis_mapping_ubo());
8990
}
9091

9192
void record_image_to_nchw_op(
@@ -106,7 +107,8 @@ void record_image_to_nchw_op(
106107
0,
107108
dst_buffer,
108109
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
109-
v_src.sizes_ubo());
110+
v_src.sizes_ubo(),
111+
v_src.axis_mapping_ubo());
110112
}
111113

112114
void record_int8_image_to_nchw_noint8_op(
@@ -127,6 +129,7 @@ void record_int8_image_to_nchw_noint8_op(
127129
dst_buffer.buffer(),
128130
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
129131
v_src.sizes_ubo(),
132+
v_src.axis_mapping_ubo(),
130133
v_src.numel_ubo());
131134
}
132135

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,8 +1227,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
12271227
GraphConfig config;
12281228
ComputeGraph graph(config);
12291229

1230-
std::vector<int64_t> size_big = {8, 64, 124};
1231-
std::vector<int64_t> size_small = {8, 1, 124};
1230+
std::vector<int64_t> size_big = {1, 8, 8};
1231+
std::vector<int64_t> size_small = {1, 1, 8};
12321232

12331233
// Build graph
12341234

@@ -1409,8 +1409,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14091409
/*shared_object_idx = */ 4);
14101410

14111411
// +2: t.sizes_ubo() for each staging shader
1412+
// +2: t.axis_mapping_ubo() for each staging shader
14121413
// +2: staging buffer for each input tensor
1413-
EXPECT_TRUE(get_vma_allocation_count() == 4);
1414+
EXPECT_TRUE(get_vma_allocation_count() == 6);
14141415

14151416
ValueRef c = graph.add_tensor(
14161417
size_big,
@@ -1427,8 +1428,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14271428

14281429
// +2: alpha UBO, broadcast UBO for arithmetic shader
14291430
// +1: t.sizes_ubo() uniform buffer for staging shader
1431+
// +1: t.axis_mapping_ubo() uniform buffer for staging shader
14301432
// +1: staging buffer for the input tensor
1431-
EXPECT_TRUE(get_vma_allocation_count() == 9);
1433+
EXPECT_TRUE(get_vma_allocation_count() == 12);
14321434

14331435
ValueRef e = graph.add_tensor(
14341436
size_big,
@@ -1444,14 +1446,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14441446

14451447
// +2: alpha UBO, broadcast UBO for arithmetic shader
14461448
// +1: t.sizes_ubo() for staging shader
1449+
// +1: t.axis_mapping_ubo() for staging shader
14471450
// +1 staging buffer for the input tensor
1448-
EXPECT_TRUE(get_vma_allocation_count() == 13);
1451+
EXPECT_TRUE(get_vma_allocation_count() == 17);
14491452

14501453
graph.prepare();
14511454
graph.encode_execute();
14521455

14531456
// +3: shared memory allocations for tensors
1454-
EXPECT_TRUE(get_vma_allocation_count() == 16);
1457+
EXPECT_TRUE(get_vma_allocation_count() == 20);
14551458

14561459
// Run graph
14571460

0 commit comments

Comments
 (0)