Skip to content

Commit 41ec7fa

Browse files
authored
[ET-VK] Integrate axis mapping into staging <-> image transfer shaders
Differential Revision: D62210117 Pull Request resolved: #5093
1 parent 9739609 commit 41ec7fa

File tree

9 files changed

+136
-31
lines changed

9 files changed

+136
-31
lines changed

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_buffer(0, "w", "nchw_out", DTYPE)}
25-
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_buffer(B, "w", "nchw_out", DTYPE)}
25+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -51,7 +52,7 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
5152

5253
void main() {
5354
const ivec3 pos = ivec3(gl_GlobalInvocationID);
54-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
55+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5556

5657
if (any(greaterThanEqual(tensor_idx, sizes))) {
5758
return;

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,42 @@ ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) {
183183
return tensor_idx;
184184
}
185185

186+
/*
 * Derive (w,h,c,n) tensor indices from (x,y,z) texture position using axis
 * mapping.
 */
ivec4 to_tensor_idx(
    ivec3 pos,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Texels are padded along the packed dim; round that size up to the next
  // multiple of 4 so the indices computed below account for the padding.
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  // One texel step along the packed axis spans 4 tensor elements.
  pos[axis_mapping[packed_dim]] *= 4;

  // Gather the W, H, C indices from their mapped texture axes.
  ivec4 tensor_idx;
  tensor_idx.x = pos[axis_mapping.x];
  tensor_idx.y = pos[axis_mapping.y];
  tensor_idx.z = pos[axis_mapping.z];
  tensor_idx.w = 0;

  // When there is more than one batch, batches are concatenated along one of
  // the WHCN dims; axis_mapping[3] names that WHCN dim (it is NOT a texture
  // axis). Unfold the combined coordinate into (batch, within-batch).
  if (sizes.w > 1) {
    const int concat_dim = axis_mapping.w;
    tensor_idx.w = tensor_idx[concat_dim] / sizes[concat_dim];
    tensor_idx[concat_dim] %= sizes[concat_dim];
  }

  return tensor_idx;
}
221+
186222
/*
187223
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim
188224
* is packed along a texel
@@ -199,6 +235,34 @@ ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
199235
return pos;
200236
}
201237

238+
/*
 * Derive (x,y,z) texture position from (w,h,c,n) tensor indices using axis
 * mapping. Inverse of the axis-mapping overload of to_tensor_idx.
 */
ivec3 to_texture_pos(
    const ivec4 idx,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  // Scatter the W, H, C indices onto their mapped texture axes.
  ivec3 pos;
  for (int dim = 0; dim < 3; ++dim) {
    pos[axis_mapping[dim]] = idx[dim];
  }

  // Adjust batch dim if needed. axis_mapping[3] is the WHCN dim index of the
  // dim batches are concatenated along (not a texture axis), so the batch
  // offset is the batch index times that dim's extent. Using sizes.w (the
  // batch count) here would not invert to_tensor_idx, which recovers the
  // batch index via `/ sizes[axis_mapping[3]]`.
  if (sizes.w > 1) {
    pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes[axis_mapping[3]];
  }

  // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
  // tensor elements in that dim.
  pos[axis_mapping[packed_dim]] /= 4;
  return pos;
}
265+
202266
/*
203267
* Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, which dim
204268
* is packed along a texel
@@ -218,6 +282,35 @@ ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) {
218282
return pos;
219283
}
220284

285+
/*
 * Derive (x,y,z,i) texel element position from the (w,h,c,n) tensor index
 * using the axis mapping. The .w component is the element's offset within
 * its texel along the packed dim.
 */
ivec4 to_texture_elem_pos(
    const ivec4 idx,
    ivec4 sizes,
    const ivec4 axis_mapping,
    const int packed_dim) {
  // Align packed dim to next multiple of 4 to account for texel padding
  sizes[packed_dim] = alignup4(sizes[packed_dim]);

  // Scatter the W, H, C indices onto their mapped texture axes.
  ivec4 pos;
  for (int dim = 0; dim < 3; ++dim) {
    pos[axis_mapping[dim]] = idx[dim];
  }

  // Adjust batch dim if needed. The batch offset along the concatenated dim
  // is the batch index times that dim's extent (sizes[axis_mapping[3]]);
  // multiplying by the batch count (sizes.w) would not invert the batch
  // unfold performed by to_tensor_idx.
  if (sizes.w > 1) {
    pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes[axis_mapping[3]];
  }

  // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4
  // tensor elements in that dim.
  pos[axis_mapping[packed_dim]] /= 4;
  pos.w = idx[packed_dim] % 4;
  return pos;
}
313+
221314
//
222315
// Texel Access and Storage
223316
//

backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_buffer(0, "w", "nchw_out", "int")}
20-
${layout_declare_tensor(1, "r", "t_in", "int8", "texture3d")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
22-
${layout_declare_ubo(3, "int", "out_numel")}
19+
${layout_declare_buffer(B, "w", "nchw_out", "int")}
20+
${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")}
21+
${layout_declare_ubo(B, "ivec4", "tensor_sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
23+
${layout_declare_ubo(B, "int", "out_numel")}
2324

2425
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2526

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)}
2121

2222
layout(std430) buffer;
2323

24-
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
25-
${layout_declare_buffer(1, "r", "nchw_in", DTYPE)}
26-
${layout_declare_ubo(2, "ivec4", "sizes")}
24+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
25+
${layout_declare_buffer(B, "r", "nchw_in", DTYPE)}
26+
${layout_declare_ubo(B, "ivec4", "sizes")}
27+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2728

2829
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2930

@@ -53,7 +54,7 @@ VEC4_T read_texel(ivec4 tensor_idx) {
5354

5455
void main() {
5556
const ivec3 pos = ivec3(gl_GlobalInvocationID);
56-
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim);
57+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
5758
if (any(greaterThanEqual(tensor_idx, sizes))) {
5859
return;
5960
}

backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ layout(std430) buffer;
1616

1717
#extension GL_EXT_control_flow_attributes : require
1818

19-
${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")}
20-
${layout_declare_buffer(1, "r", "nchw_in", "int")}
21-
${layout_declare_ubo(2, "ivec4", "tensor_sizes")}
19+
${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")}
20+
${layout_declare_buffer(B, "r", "nchw_in", "int")}
21+
${layout_declare_ubo(B, "ivec4", "sizes")}
22+
${layout_declare_ubo(B, "ivec4", "axis_mapping")}
2223

2324
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
2425

@@ -36,7 +37,7 @@ int extend_sign(int x) {
3637

3738
ivec4 read_texel(ivec4 tensor_idx) {
3839
const ivec4 buf_indices = get_texel_nchw_buffer_ixs(
39-
tensor_idx, tensor_sizes, packed_dim);
40+
tensor_idx, sizes, packed_dim);
4041

4142
int shift = (1 << 8) - 1;
4243
ivec4 masks;
@@ -51,7 +52,7 @@ ivec4 read_texel(ivec4 tensor_idx) {
5152
ivec4 out_tex = ivec4(0);
5253

5354
[[unroll]] for (int i = 0; i < 4; ++i) {
54-
if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) {
55+
if (tensor_idx[packed_dim] + i < sizes[packed_dim]) {
5556
int in_texel = nchw_in[buf_indices[i] / 4];
5657
int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4));
5758
extracted_val = extend_sign(extracted_val);
@@ -64,9 +65,9 @@ ivec4 read_texel(ivec4 tensor_idx) {
6465

6566
void main() {
6667
const ivec3 pos = ivec3(gl_GlobalInvocationID);
67-
const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim);
68+
const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim);
6869

69-
if (any(greaterThanEqual(tensor_idx, tensor_sizes))) {
70+
if (any(greaterThanEqual(tensor_idx, sizes))) {
7071
return;
7172
}
7273

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ValueRef prepack_biases(
106106
graph.create_local_wg_size(v),
107107
vref,
108108
v,
109-
{t->sizes_ubo()},
109+
{t->sizes_ubo(), t->axis_mapping_ubo()},
110110
// Specialization constants
111111
{SV(t->packed_dim_whcn_idx())}));
112112

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void add_staging_to_tensor_node(
3131
graph.strides_ubo(out_tensor),
3232
graph.numel_ubo(out_tensor)});
3333
} else {
34-
ubos.append(graph.sizes_ubo(out_tensor));
34+
ubos.append(
35+
{graph.sizes_ubo(out_tensor), graph.axis_mapping_ubo(out_tensor)});
3536
}
3637

3738
graph.execute_nodes().emplace_back(new ExecuteNode(
@@ -69,7 +70,8 @@ void add_tensor_to_staging_node(
6970
graph.strides_ubo(in_tensor),
7071
graph.numel_ubo(in_tensor)});
7172
} else {
72-
ubos.append(graph.sizes_ubo(in_tensor));
73+
ubos.append(
74+
{graph.sizes_ubo(in_tensor), graph.axis_mapping_ubo(in_tensor)});
7375
}
7476

7577
// Normally, the image_to_nchw shader is structured so that each thread reads
@@ -113,7 +115,7 @@ ValueRef prepack(
113115
if (graph.is_buffer_storage(v)) {
114116
ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)});
115117
} else {
116-
ubos.append(graph.sizes_ubo(v));
118+
ubos.append({graph.sizes_ubo(v), graph.axis_mapping_ubo(v)});
117119
}
118120

119121
graph.prepack_nodes().emplace_back(new PrepackNode(

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ void record_nchw_to_image_op(
8585
vkapi::PipelineStage::COMPUTE,
8686
vkapi::MemoryAccessType::WRITE),
8787
src_buffer,
88-
v_dst.sizes_ubo());
88+
v_dst.sizes_ubo(),
89+
v_dst.axis_mapping_ubo());
8990
}
9091

9192
void record_image_to_nchw_op(
@@ -106,7 +107,8 @@ void record_image_to_nchw_op(
106107
0,
107108
dst_buffer,
108109
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
109-
v_src.sizes_ubo());
110+
v_src.sizes_ubo(),
111+
v_src.axis_mapping_ubo());
110112
}
111113

112114
void record_int8_image_to_nchw_noint8_op(
@@ -127,6 +129,7 @@ void record_int8_image_to_nchw_noint8_op(
127129
dst_buffer.buffer(),
128130
v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
129131
v_src.sizes_ubo(),
132+
v_src.axis_mapping_ubo(),
130133
v_src.numel_ubo());
131134
}
132135

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
12331233
GraphConfig config;
12341234
ComputeGraph graph(config);
12351235

1236-
std::vector<int64_t> size_big = {8, 64, 124};
1237-
std::vector<int64_t> size_small = {8, 1, 124};
1236+
std::vector<int64_t> size_big = {1, 8, 8};
1237+
std::vector<int64_t> size_small = {1, 1, 8};
12381238

12391239
// Build graph
12401240

@@ -1415,8 +1415,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14151415
/*shared_object_idx = */ 4);
14161416

14171417
// +2: t.sizes_ubo() for each staging shader
1418+
// +2: t.axis_mapping_ubo() for each staging shader
14181419
// +2: staging buffer for each input tensor
1419-
EXPECT_TRUE(get_vma_allocation_count() == 4);
1420+
EXPECT_TRUE(get_vma_allocation_count() == 6);
14201421

14211422
ValueRef c = graph.add_tensor(
14221423
size_big,
@@ -1433,8 +1434,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14331434

14341435
// +2: alpha UBO, broadcast UBO for arithmetic shader
14351436
// +1: t.sizes_ubo() uniform buffer for staging shader
1437+
// +1: t.axis_mapping_ubo() uniform buffer for staging shader
14361438
// +1: staging buffer for the input tensor
1437-
EXPECT_TRUE(get_vma_allocation_count() == 9);
1439+
EXPECT_TRUE(get_vma_allocation_count() == 12);
14381440

14391441
ValueRef e = graph.add_tensor(
14401442
size_big,
@@ -1450,14 +1452,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
14501452

14511453
// +2: alpha UBO, broadcast UBO for arithmetic shader
14521454
// +1: t.sizes_ubo() for staging shader
1455+
// +1: t.axis_mapping_ubo() for staging shader
14531456
// +1 staging buffer for the input tensor
1454-
EXPECT_TRUE(get_vma_allocation_count() == 13);
1457+
EXPECT_TRUE(get_vma_allocation_count() == 17);
14551458

14561459
graph.prepare();
14571460
graph.encode_execute();
14581461

14591462
// +3: shared memory allocations for tensors
1460-
EXPECT_TRUE(get_vma_allocation_count() == 16);
1463+
EXPECT_TRUE(get_vma_allocation_count() == 20);
14611464

14621465
// Run graph
14631466

0 commit comments

Comments
 (0)