Commit 6f933c5

[ET-VK] Simplifying conv1d op shader by changing it to process one output texel per thread.

This diff changes the conv1d shader to process one output texel per thread, increasing GPU occupancy and improving performance.

Differential Revision: [D74097560](https://our.internmc.facebook.com/intern/diff/D74097560/)
ghstack-source-id: 281752381
Pull Request resolved: #10665

1 parent 34e676d · commit 6f933c5
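For a sense of the occupancy change, here is a minimal C++ sketch of the invocation-count arithmetic; the tensor sizes and names below are illustrative, not from the commit:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical output shape: batches N, channels out_C, length out_L.
  const uint32_t N = 1, out_C = 512, out_L = 128;

  // Before: one invocation per output channel; each invocation looped
  // over the whole length and batch dimensions internally.
  const uint32_t invocations_before = out_C;

  // After: one invocation per output texel (out_l, out_c, n).
  const uint32_t invocations_after = N * out_C * out_L;

  std::printf("before: %u, after: %u\n", invocations_before, invocations_after);
  return 0;
}

With these sizes the dispatch grows from 512 to 65,536 invocations, which is what lets the GPU keep far more of its cores busy at once.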

File tree: 2 files changed, +59 -61 lines

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 41 additions & 49 deletions
@@ -56,9 +56,9 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
+// This implementation performs N x out_C x out_L shader invocations, where each invocation
 // calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
+// computes out_L results.
 //
 // Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
 // shader invocations, where each invocation computes 1 result. But that
@@ -70,61 +70,53 @@ void main() {
     return;
   }
 
-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = lpos.y;
-  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const int out_c = lpos.y;
 
   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
+
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" is the batch index.
+  const int N = lpos.z;
 
   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. This may not be a trivial
-        // choice for performance, since it requires more data fetches. For
-        // some sequence models, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has a large
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use channel-packing
-        // for the weight tensor, yielding a 75% reduction in weight-tensor
-        // memory.
-
-        // It is possible to further reduce the memory footprint by swapping the
-        // dimensions, using x extent for out_channel, and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. This may not be a trivial
+    // choice for performance, since it requires more data fetches. For
+    // some sequence models, we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has a large
+    // out_channel >> kernel, leading to non-optimal use of memory as the
+    // weight tensor gets very deep. As a mitigation, we use channel-packing
+    // for the weight tensor, yielding a 75% reduction in weight-tensor
+    // memory.
+
+    // It is possible to further reduce the memory footprint by swapping the
+    // dimensions, using x extent for out_channel, and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
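As a reading aid, below is a scalar C++ sketch of the work one shader invocation now performs for its (out_l, out_c, n) position. Texel loads are flattened into plain array indexing, the final min/max clamp is omitted, and every name here is illustrative rather than part of the shader:

#include <vector>

// Computes a single conv1d output value at (n, out_c, out_l), mirroring
// the shader's loop structure. Tensors are flattened row-major:
// in = (N, in_C, in_L), weight = (out_C, in_C / G, K), bias = (out_C,).
float conv1d_one_output(
    const std::vector<float>& in,
    const std::vector<float>& weight,
    const std::vector<float>& bias,
    int n, int out_c, int out_l,
    int in_C, int in_L, int kernel_size,
    int stride, int padding, int dilation,
    int in_group_size, int out_group_size) {
  // First input channel of the group that feeds this output channel.
  const int c_start = (out_c / out_group_size) * in_group_size;

  // Leftmost input index covered by the kernel at this output position;
  // this replaces the old rolling l_start/l_end bookkeeping.
  const int in_l = out_l * stride - padding;

  float sum = 0.0f;
  for (int c = 0; c < in_group_size; ++c) {
    const int in_c = c_start + c;
    for (int k = 0; k < kernel_size; ++k) {
      const int l = in_l + k * dilation;
      if (l < 0 || l >= in_L) {
        continue;  // out-of-bounds taps contribute zero, like the padded reads
      }
      sum += weight[(out_c * in_group_size + c) * kernel_size + k] *
             in[(n * in_C + in_c) * in_L + l];
    }
  }
  return sum + bias[out_c];
}

The 75% weight-memory figure in the shader comments follows from channel-packing storing four channel values per RGBA texel instead of one, shrinking the texture's channel extent by 4x.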

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 18 additions & 12 deletions
@@ -505,17 +505,23 @@ void add_conv1d_node(
 
   check_conv_args(*t_in, *t_out);
 
-  int32_t in_channels = in_sizes.at(1);
-  int32_t out_channels = weight_sizes.at(0);
-  int32_t kernel_size = weight_sizes.at(2);
-  int32_t stride_size = graph.get_int_list(stride)->at(0);
-  int32_t padding_size = graph.get_int_list(padding)->at(0);
-  int32_t dilation_size = graph.get_int_list(dilation)->at(0);
-  int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
-  int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
-
-  utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
-  utils::uvec3 local_size = {1, 64, 1};
+  const int32_t in_channels = in_sizes.at(1);
+  const int32_t out_channels = weight_sizes.at(0);
+  const int32_t kernel_size = weight_sizes.at(2);
+  const int32_t stride_size = graph.get_int_list(stride)->at(0);
+  const int32_t padding_size = graph.get_int_list(padding)->at(0);
+  const int32_t dilation_size = graph.get_int_list(dilation)->at(0);
+  const int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
+  const int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
+
+  const utils::uvec3 global_size = {
+      // out length
+      graph.size_at<uint32_t>(-1, out),
+      // out channels
+      static_cast<uint32_t>(out_channels),
+      // out batches
+      graph.size_at<uint32_t>(-3, out)};
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
 
   Kernel1dParams kernel_params = {
       kernel_size,
@@ -525,7 +531,7 @@ void add_conv1d_node(
       in_group_size,
       out_group_size};
 
-  OutputParams out_params = {out_min_val, out_max_val};
+  const OutputParams out_params = {out_min_val, out_max_val};
 
   std::string kernel_name("conv1d");
   if (clamp_out) {
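The diff relies on graph.create_local_wg_size to pick a local workgroup size for the new 3D grid; its actual heuristic is not shown here. Purely as an illustration of the kind of rule such a helper can apply, here is a hedged C++ sketch that keeps the thread count at 64 (matching the old {1, 64, 1} budget) while spreading it across the axes with the most remaining work:

#include <array>
#include <cstdint>

// Illustrative stand-in for a local-workgroup-size heuristic: repeatedly
// double the axis whose global extent is least covered, up to 64 threads.
// This is NOT ExecuTorch's implementation, just a plausible shape for one.
std::array<uint32_t, 3> pick_local_wg_size(std::array<uint32_t, 3> global) {
  std::array<uint32_t, 3> local = {1, 1, 1};
  for (int step = 0; step < 6; ++step) {  // 2^6 = 64 threads max
    int axis = 0;
    for (int i = 1; i < 3; ++i) {
      if (global[i] / local[i] > global[axis] / local[axis]) {
        axis = i;  // grow the axis with the most uncovered work
      }
    }
    if (global[axis] / local[axis] <= 1) {
      break;  // every axis is already fully covered
    }
    local[axis] *= 2;
  }
  return local;
}

For a shape such as global = {out_L, out_C, N} = {128, 512, 1}, this sketch yields a local size whose product is 64 with the batch axis left at 1, spreading the fixed thread budget over the two large output axes instead of pinning it all to channels.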
