Skip to content

conv1d general case #3223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 87 additions & 61 deletions backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,78 +21,104 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;

layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
int data;
}
out_channels;

layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
int data;
}
in_length;

layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
int data;
}
kernel_size;
// Upper bounds for the invocation position. main() returns early when any
// component of gl_GlobalInvocationID is greater than or equal to the matching
// component of out_limits.
layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};

// Sizes of the input tensor. main() reads in_sizes.x as the input length and
// in_sizes.z as the batch size; the other components are not read by main().
layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
ivec4 in_sizes;
};

// Convolution parameters consumed by main().
// NOTE(review): the field names and order match Kernel1dParams on the host
// side; presumably the two must be kept in sync - confirm.
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
// Number of kernel taps; main() walks k over [0, kernel_size) in steps of 4.
int kernel_size;
// Amount added to the input length index between consecutive outputs.
int stride;
// The input length index starts at -padding and the upper loop bound is
// extended by +padding.
int padding;
// Spacing between taps: tap k reads the input at in_l + k * dilation.
int dilation;
// Number of input channels iterated per output channel (one input group).
int in_group_size;
// Divisor that maps an output channel to its group index
// (out_c / out_group_size).
int out_group_size;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* This implementation optimizes for simplicity (and partially performance) for a
* (1, C, L) input where C == groups. Hence we only focus on calculating the
* rolling kernel of the L dimension.
*/
// Let us define
//
// input = (N, in_C, in_L),
// output = (N, out_C, out_L),
// groups = G,
// kernel = K,
//
// which results in shapes
//
// weight = (out_C, in_C / G, K),
// bias = (out_C,).
//
// This implementation performs out_C shader invocations, where each invocation
// calculates the rolling kernel of the length dimension for each batch, i.e.,
// computes out_L * N results.
//
// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
// shader invocations, where each invocation computes 1 result. But that
// performs worse.
void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

// The global workgroup should have taken care of it. We only perform one
// work item for each 1d tensor on lengths
if (pos.x >= 1) {
if (any(greaterThanEqual(pos, out_limits))) {
return;
}

int c = pos.y;
if (c >= out_channels.data) {
return;
}

// Assume n = 1, do not handle n > 1 case for now.
int n = pos.z;
if (n >= 1) {
return;
}

vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);

for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
vec4 v = vec4(0);
for (int k = 0; k < kernel_size.data; ++k) {
const ivec3 in_pos = ivec3(i+k, c, 0);
const vec4 input_value = texelFetch(image_in, in_pos, 0);

// Note that we are reading the weight in the inner loop. This could be
// improved by moving the read before the outer loop, since the weight vector
// is constant for the entire call.

// weight in input-space: (c, 0, k);
// notice that c is 4-packed. We need to mod 4 to get the actual weight.
const ivec3 w_pos = ivec3(k, 0, c / 4);
const vec4 weight = texelFetch(kernel_in, w_pos, 0);

float w = weight.x;
if (c % 4 == 1) {
w = weight.y;
} else if (c % 4 == 2) {
w = weight.z;
} else if (c % 4 == 3) {
w = weight.w;
int in_length = in_sizes.x;
int batch_size = in_sizes.z;

// "out_c" is the output's channel index where we write our result.
// Across shader invocations, this is the only value that varies.
int out_c = pos.y;
vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);

// "in_c" tracks the input's channel start index.
// We iterate over the input group that corresponds to the output group.
int c_start = (out_c / out_group_size) * in_group_size;
int c_end = c_start + in_group_size;

// "in_l" tracks the input's length start index for our input-kernel overlay
// region.
int l_start = -padding;
int l_end = in_length + padding - dilation * (kernel_size - 1);

// Since the input/output tensors are channel-packed, which is along the
// batch dimension, we can batch-read/write four elements at a time.
for (int n = 0; n < batch_size; n += 4) {
// "out_l" tracks the output's length index where we write our result.
int out_l = 0;

for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
vec4 sum = vec4(0);

for (int in_c = c_start; in_c < c_end; ++in_c) {
// "k" tracks the kernel's index for our input-kernel computation.
// The loop reads out-of-bounds zeros, but trying to avoid them complicates
// the for-loop conditions, which results in worse performance.
for (int k = 0; k < kernel_size; k += 4) {
// Since the weight tensor is width-packed, which is along the length
// dimension, we can batch-read four elements at a time.
const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
const vec4 weight = texelFetch(kernel_in, w_pos, 0);

const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);

const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);

const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);

const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
}
}

v += w * input_value.x;
ivec3 out_pos = ivec3(out_l, out_c, n / 4);
imageStore(image_out, out_pos, sum + bias.x);
}

ivec3 out_pos = ivec3(i, c, 0);
imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
}
}
78 changes: 37 additions & 41 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ void resize_conv1d_node(
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);

int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);

const std::vector<int64_t>& weight_sizes = weight_ref->sizes;

const std::vector<int64_t>& in_sizes = self->sizes();
Expand All @@ -71,8 +76,9 @@ void resize_conv1d_node(
int64_t in_length = in_sizes.at(2);

new_out_sizes.at(0) = in_sizes.at(0);
new_out_sizes.at(1) = in_sizes.at(1);
new_out_sizes.at(2) = in_length - kernel_size + 1;
new_out_sizes.at(1) = weight_sizes.at(0);
new_out_sizes.at(2) = calc_out_size(
in_length, kernel_size, stride_size, padding_size, dilation_size, false);

out->virtual_resize(new_out_sizes);
}
Expand Down Expand Up @@ -244,10 +250,6 @@ ValueRef prepack_weights(
}

void check_conv_args(const vTensor& in, const vTensor& out) {
if (in.sizes().at(0) > 1) {
VK_THROW(
"aten.convolution.default: input batch size > 1 is not supported yet!");
}
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
}
Expand All @@ -260,7 +262,7 @@ struct Conv2dParams final {
Conv2dParams create_conv2d_params(
ComputeGraph& graph,
const ValueRef weight,
const KernelParams& p,
const Kernel2dParams& p,
const bool transposed) {
const auto& overlay_region = api::utils::make_ivec2({
p.kernel_size.data[0] +
Expand All @@ -275,7 +277,7 @@ Conv2dParams create_conv2d_params(
return {overlay_region, in_group_size};
}

void check_conv2d_params(const KernelParams& p, const bool transposed) {
void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
if (transposed) {
if (p.dilation.data[0] > 1 || p.dilation.data[1] > 1) {
VK_THROW(
Expand Down Expand Up @@ -342,12 +344,15 @@ void add_conv2d_node(

vTensorPtr t_in = graph.get_tensor(arg_in);
vTensorPtr t_out = graph.get_tensor(out);
if (t_in->sizes().at(0) > 1) {
VK_THROW("conv2d: input batch size > 1 is not supported yet!");
}
check_conv_args(*t_in, *t_out);

api::utils::uvec3 global_size = t_out->extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

KernelParams kernel_params = create_kernel_params(
Kernel2dParams kernel_params = create_kernel2d_params(
graph,
weight,
/*kernel_size_only = */ false,
Expand Down Expand Up @@ -395,8 +400,7 @@ void add_conv1d_node(
const ValueRef groups,
const ValueRef out) {
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
ValueRef arg_weight =
prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
ValueRef arg_weight = prepack_if_tensor_ref(graph, weight, api::kWidthPacked);
ValueRef arg_bias = prepack_biases(
graph,
bias,
Expand All @@ -414,37 +418,29 @@ void add_conv1d_node(
std::vector<int64_t> in_sizes = t_in->sizes();
std::vector<int64_t> weight_sizes = t_weight->sizes();
std::vector<int64_t> out_sizes = t_out->sizes();
IntListPtr stride_sizes = graph.get_int_list(stride);
IntListPtr padding_sizes = graph.get_int_list(padding);
IntListPtr dilation_sizes = graph.get_int_list(dilation);
int64_t weight_out_channels = weight_sizes.at(0);
int64_t kernel_size = weight_sizes.at(2);
int64_t in_length = in_sizes.at(2);

VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor");
VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor");
VK_CHECK_COND(
stride_sizes->size() == 1 && stride_sizes->at(0) == 1,
"stride must be 1");
VK_CHECK_COND(
padding_sizes->size() == 1 && padding_sizes->at(0) == 0,
"padding must be 0");
VK_CHECK_COND(
dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1,
"dilation must be 1");
VK_CHECK_COND(
groups_val == in_sizes.at(1), "groups must be equal to in_channels");
VK_CHECK_COND(
groups_val == weight_sizes.at(0),
"groups must be equal to weight_sizes.at(0)");
VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1");

check_conv_args(*t_in, *t_out);

api::utils::uvec3 global_size = {
1, static_cast<uint32_t>(weight_out_channels), 1};
int32_t in_channels = in_sizes.at(1);
int32_t out_channels = weight_sizes.at(0);
int32_t kernel_size = weight_sizes.at(2);
int32_t stride_size = graph.get_int_list(stride)->at(0);
int32_t padding_size = graph.get_int_list(padding)->at(0);
int32_t dilation_size = graph.get_int_list(dilation)->at(0);
int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);

api::utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
api::utils::uvec3 local_size = {1, 1, 1};

Kernel1dParams kernel_params = {
kernel_size,
stride_size,
padding_size,
dilation_size,
in_group_size,
out_group_size};

std::string kernel_name("conv1d");
kernel_name.reserve(kShaderNameReserve);

Expand All @@ -460,15 +456,15 @@ void add_conv1d_node(
{{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
// Shader params buffers
{
graph.create_params_buffer(weight_out_channels),
graph.create_params_buffer(in_length),
graph.create_params_buffer(kernel_size),
t_out->texture_limits_ubo(),
t_in->sizes_ubo(),
graph.create_params_buffer(kernel_params),
},
// Specialization Constants
{},
// Resizing Logic
resize_conv1d_node,
{weight}));
{weight, stride, padding, dilation}));
}

void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/impl/Pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void add_max_pool2d_node(
std::string kernel_name("max_pool2d");
add_dtype_suffix(kernel_name, *t_out);

KernelParams kernel_params = create_kernel_params(
Kernel2dParams kernel_params = create_kernel2d_params(
graph,
kernel_size,
/*kernel_size_only = */ true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ api::utils::ivec2 make_ivec2_kernel_size(
}
}

KernelParams create_kernel_params(
Kernel2dParams create_kernel2d_params(
ComputeGraph& graph,
const ValueRef weight,
const bool kernel_size_only,
Expand Down
21 changes: 19 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,38 @@

namespace vkcompute {

struct KernelParams final {
// Kernel parameters for 2D ops (used by the conv2d and max_pool2d nodes).
// Each field is a two-component integer vector whose components are read
// through `.data[0]` / `.data[1]`.
// NOTE(review): which component is height and which is width is not visible
// here - confirm against create_kernel2d_params.
struct Kernel2dParams final {
api::utils::ivec2 kernel_size;
api::utils::ivec2 stride;
api::utils::ivec2 padding;
api::utils::ivec2 dilation;
};

KernelParams create_kernel_params(
// Scalar parameters for the conv1d shader. add_conv1d_node fills one of these
// positionally and hands it to graph.create_params_buffer().
// NOTE(review): the fields have the same names and order as the `Params`
// uniform block in conv1d.glsl; presumably the order must stay in sync with
// the shader - confirm before reordering or inserting members.
struct Kernel1dParams final {
// Kernel length, taken from weight_sizes.at(2) by add_conv1d_node.
int kernel_size;
// First element of the stride int list.
int stride;
// First element of the padding int list.
int padding;
// First element of the dilation int list.
int dilation;
// in_channels / groups.
int in_group_size;
// out_channels / groups.
int out_group_size;
};

Kernel2dParams create_kernel2d_params(
ComputeGraph& graph,
const ValueRef weight,
const bool kernel_size_only,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation);

// Computes the output extent of a single dimension from the input extent and
// the kernel parameters. resize_conv1d_node uses it (with ceil_mode = false)
// to derive the output length of conv1d.
// NOTE(review): the definition is not visible here; presumably this is the
// usual floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1,
// with ceil_mode switching the rounding to ceil - confirm against the
// implementation.
int64_t calc_out_size(
const int64_t in_size,
const int64_t kernel_size,
const int64_t stride,
const int64_t padding,
const int64_t dilation,
const bool ceil_mode);

std::vector<int64_t> calc_out_sizes_hw(
ComputeGraph& graph,
const std::vector<int64_t>& in_sizes,
Expand Down
Loading