[ET-VK] Simplifying conv1d op shader by changing it to process one output texel per thread. #10690

Merged: 1 commit, May 5, 2025
94 changes: 41 additions & 53 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
@@ -56,75 +56,63 @@ const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
 // weight = (out_C, in_C / G, K),
 // bias = (out_C,).
 //
-// This implementation performs out_C shader invocations, where each invocation
+// This implementation performs N x out_C x out_L shader invocations, where each invocation
 // calculates the rolling kernel of the length dimension for each batch, i.e.,
-// computes out_L * N results.
-//
-// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
-// shader invocations, where each invocation computes 1 result. But that
-// performs worse.
+// computes a single output texel.
 void main() {
   const ivec3 lpos = ivec3(gl_GlobalInvocationID);
 
   if (any(greaterThanEqual(lpos, out_limits))) {
     return;
   }
 
-  int in_length = in_sizes.x;
-  int batch_size = in_sizes.z;
-
   // "out_c" is the output's channel index where we write our result.
   // Across shader invocations, this is the only value that varies.
-  int out_c = lpos.y;
-  VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const int out_c = lpos.y;
 
   // "in_c" tracks the input's channel start index.
   // We iterate over the input group that corresponds to the output group.
-  int c_start = (out_c / out_group_size) * in_group_size;
-  int c_end = c_start + in_group_size;
+  const int c_start = (out_c / out_group_size) * in_group_size;
+  const int c_end = c_start + in_group_size;
 
+  // "out_l" tracks the output's length index where we write our result.
+  const int out_l = lpos.x;
+
+  // "N" is the batch texel index; each texel packs 4 batch elements.
+  const int N = lpos.z;
+
   // "in_l" tracks the input's length start index for our input-kernel overlay
   // region.
-  int l_start = -padding;
-  int l_end = in_length + padding - dilation * (kernel_size - 1);
-
-  // Since the input/output tensors are channel-packed, which is along the
-  // batch dimension, we can batch-read/write four elements at a time.
-  for (int n = 0; n < batch_size; n += 4) {
-    // "out_l" tracks the output's length index where we write our result.
-    int out_l = 0;
-
-    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
-      VEC4_T sum = VEC4_T(0);
-
-      for (int in_c = c_start; in_c < c_end; ++in_c) {
-        // "k" tracks the kernel's index for our input-kernel computation.
-        // It reads out-of-bound zeros, but trying to avoid them complicates
-        // for-loop conditions, which results in worse performance.
-
-        // The weight tensor is channel-packed. It may not be trival choice for
-        // performance reason since need to have more data fetch. The reason is
-        // for some sequence model, we found that the weight tensor
-        // (out_channel, in_channel / group, kernel) often has a large
-        // out_channel >> kernel, leading to non-optimal use of memory as the
-        // weight tensor gets very deep. As a mitigation, we use channel-packing
-        // for the weight tensor, yielding a 75% reduction in weight-tensor
-        // memory.
-
-        // It is possible to further reduce the memory footprint by swapping the
-        // dimensions, using x extent for out_channel, and y for kernel.
-        for (int k = 0; k < kernel_size; k += 1) {
-          const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
-          const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
-          VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
-
-          ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-          sum = fma(weight, load_texel(t_in, in_pos), sum);
-        }
-      }
-
-      const ivec3 out_lpos = ivec3(out_l, out_c, n / 4);
-      write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
+  const int in_l = out_l * stride - padding;
+  VEC4_T sum = VEC4_T(0);
+
+  for (int in_c = c_start; in_c < c_end; ++in_c) {
+    // "k" tracks the kernel's index for our input-kernel computation.
+    // It reads out-of-bound zeros, but trying to avoid them complicates
+    // for-loop conditions, which results in worse performance.
+
+    // The weight tensor is channel-packed. This may not be the obvious choice
+    // for performance, since it requires more data fetches. However, for some
+    // sequence models we found that the weight tensor
+    // (out_channel, in_channel / group, kernel) often has
+    // out_channel >> kernel, leading to non-optimal memory use as the weight
+    // tensor gets very deep. As a mitigation, we use channel-packing for the
+    // weight tensor, yielding a 75% reduction in weight-tensor memory.
+
+    // It is possible to further reduce the memory footprint by swapping the
+    // dimensions, using the x extent for out_channel and y for kernel.
+    for (int k = 0; k < kernel_size; k++) {
+      const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+      const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+      VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
+
+      const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map);
+      sum = fma(weight, load_texel(t_in, in_pos), sum);
     }
   }
+
+  const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);
+  const ivec3 out_lpos = ivec3(out_l, out_c, N);
+  write_texel_lpos(t_out, out_lpos, op(sum + bias.x, out_min, out_max), out_axis_map);
 }
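
To make the new indexing easy to verify, here is a small CPU reference of the single result each invocation now computes. This is an illustrative sketch, not code from the PR: the function and parameter names are hypothetical, and the batch dimension and the 4-wide texel packing are omitted for clarity.

#include <vector>

// Hypothetical scalar reference for one conv1d output element, mirroring the
// shader's indexing (illustration only; not part of ExecuTorch).
float conv1d_one_output(
    const std::vector<float>& in,     // [in_C, in_L], row-major
    const std::vector<float>& weight, // [out_C, in_C / G, K], row-major
    const std::vector<float>& bias,   // [out_C]
    int in_C, int in_L, int out_C, int K, int G,
    int stride, int padding, int dilation,
    int out_c, int out_l) {
  const int in_group_size = in_C / G;
  const int out_group_size = out_C / G;
  // Group mapping, as in the shader: this output channel reads input
  // channels [c_start, c_start + in_group_size).
  const int c_start = (out_c / out_group_size) * in_group_size;
  // Matches "const int in_l = out_l * stride - padding;" in the shader.
  const int in_l = out_l * stride - padding;
  float sum = 0.0f;
  for (int ic = 0; ic < in_group_size; ++ic) {
    for (int k = 0; k < K; ++k) {
      const int il = in_l + k * dilation;
      if (il < 0 || il >= in_L) {
        continue; // the shader instead reads zeros out of bounds
      }
      sum += weight[(out_c * in_group_size + ic) * K + k] *
             in[(c_start + ic) * in_L + il];
    }
  }
  return sum + bias[out_c];
}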
31 changes: 19 additions & 12 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -505,17 +505,24 @@ void add_conv1d_node(
 
   check_conv_args(*t_in, *t_out);
 
-  int32_t in_channels = in_sizes.at(1);
-  int32_t out_channels = weight_sizes.at(0);
-  int32_t kernel_size = weight_sizes.at(2);
-  int32_t stride_size = graph.get_int_list(stride)->at(0);
-  int32_t padding_size = graph.get_int_list(padding)->at(0);
-  int32_t dilation_size = graph.get_int_list(dilation)->at(0);
-  int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
-  int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
-
-  utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
-  utils::uvec3 local_size = {1, 64, 1};
+  const int32_t in_channels = in_sizes.at(1);
+  const int32_t out_channels = weight_sizes.at(0);
+  const int32_t kernel_size = weight_sizes.at(2);
+  const int32_t stride_size = graph.get_int_list(stride)->at(0);
+  const int32_t padding_size = graph.get_int_list(padding)->at(0);
+  const int32_t dilation_size = graph.get_int_list(dilation)->at(0);
+  const int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
+  const int32_t out_group_size =
+      static_cast<int64_t>(out_channels / groups_val);
+
+  const utils::uvec3 global_size = {
+      // out length
+      graph.size_at<uint32_t>(-1, out),
+      // out channels
+      static_cast<uint32_t>(out_channels),
+      // out batches
+      utils::div_up_4(graph.size_at<uint32_t>(-3, out))};
+  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);
 
   Kernel1dParams kernel_params = {
       kernel_size,
@@ -525,7 +532,7 @@ void add_conv1d_node(
       in_group_size,
       out_group_size};
 
-  OutputParams out_params = {out_min_val, out_max_val};
+  const OutputParams out_params = {out_min_val, out_max_val};
 
   std::string kernel_name("conv1d");
   if (clamp_out) {
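
For reference, the dispatch now launches one invocation per output texel: x covers out_L, y covers out_C, and z covers the batch texels, i.e. ceil(N / 4), since the batch dimension is packed four elements per texel. Below is a minimal sketch of that arithmetic, using assumed stand-ins rather than the actual ExecuTorch utils types.

#include <cstdint>

// Stand-in for utils::uvec3; illustrative only.
struct UVec3 {
  uint32_t x, y, z;
};

// Computes the global work-group extents described above: one invocation per
// output texel of a (N, out_C, out_L) output.
UVec3 conv1d_global_size(uint32_t out_L, uint32_t out_C, uint32_t N) {
  return UVec3{
      out_L,        // x: output length
      out_C,        // y: output channels
      (N + 3) / 4}; // z: batch texels, i.e. div_up_4(N)
}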