Skip to content

[ET-VK] Refactor Pool.cpp #2836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 14 additions & 33 deletions backends/vulkan/runtime/graph/ops/impl/Pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,23 @@ void resize_max_pool2d_node(
size_t ndim = self.sizes().size();
std::vector<int64_t> new_out_sizes(ndim);

// Batch
// Batch, Channel
if (ndim == 4) {
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
}
// Channel
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);

const auto kernel_size = reverse(*graph, extra_args[0]);
const auto stride = reverse(*graph, extra_args[1]);
const auto padding = reverse(*graph, extra_args[2]);
const auto dilation = reverse(*graph, extra_args[3]);
const bool ceil_mode = graph->get_val(extra_args[4]).toBool();

// Height
new_out_sizes.at(ndim - 2) = calc_out_size(
self.sizes().at(ndim - 2),
kernel_size.data[1],
stride.data[1],
padding.data[1],
dilation.data[1],
ceil_mode);
// Width
new_out_sizes.at(ndim - 1) = calc_out_size(
self.sizes().at(ndim - 1),
kernel_size.data[0],
stride.data[0],
padding.data[0],
dilation.data[0],
ceil_mode);

VK_CHECK_COND(new_out_sizes.at(ndim - 2) >= 1);
VK_CHECK_COND(new_out_sizes.at(ndim - 1) >= 1);
// Height, Width
const auto new_out_sizes_hw = calc_out_sizes_hw(
*graph,
self.sizes(),
extra_args[0],
extra_args[1],
extra_args[2],
extra_args[3],
extra_args[4]);
new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);

out.virtual_resize(new_out_sizes);
indices.virtual_resize(new_out_sizes);
Expand Down Expand Up @@ -96,12 +81,8 @@ void add_max_pool2d_node(
kernel_name << "max_pool2d";
apply_dtype_suffix(kernel_name, t_out);

KernelParams kernel_params{
reverse(graph, kernel_size),
reverse(graph, stride),
reverse(graph, padding),
reverse(graph, dilation),
};
KernelParams kernel_params =
create_kernel_params(graph, kernel_size, stride, padding, dilation);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
Expand Down
65 changes: 60 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,80 @@

namespace vkcompute {

api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) {
return api::utils::make_ivec2(
graph.get_val(vref).toIntList(), /*reverse = */ true);
}

KernelParams create_kernel_params(
ComputeGraph& graph,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation) {
return {
make_ivec2_from_list(graph, kernel_size),
make_ivec2_from_list(graph, stride),
make_ivec2_from_list(graph, padding),
make_ivec2_from_list(graph, dilation),
};
}

int64_t calc_out_size(
const int64_t in_size,
const int64_t kernel,
const int64_t kernel_size,
const int64_t stride,
const int64_t padding,
const int64_t dilation,
const bool ceil_mode) {
int64_t c = ceil_mode ? stride - 1 : 0;
int64_t out_size =
(in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
(in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride +
1;
if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
--out_size;
}
return out_size;
}

api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref) {
return api::utils::make_ivec2(
graph.get_val(vref).toIntList(), /*reverse=*/true);
std::vector<int64_t> calc_out_sizes_hw(
ComputeGraph& graph,
const std::vector<int64_t>& in_sizes,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef ceil_mode) {
const int64_t ndim = in_sizes.size();
std::vector<int64_t> out_sizes(2);

const auto kernel_vec = make_ivec2_from_list(graph, kernel_size);
const auto stride_vec = make_ivec2_from_list(graph, stride);
const auto padding_vec = make_ivec2_from_list(graph, padding);
const auto dilation_vec = make_ivec2_from_list(graph, dilation);
const bool ceil_mode_val = graph.get_val(ceil_mode).toBool();

// Height
out_sizes.at(0) = calc_out_size(
in_sizes.at(ndim - 2),
kernel_vec.data[1],
stride_vec.data[1],
padding_vec.data[1],
dilation_vec.data[1],
ceil_mode_val);
// Width
out_sizes.at(1) = calc_out_size(
in_sizes.at(ndim - 1),
kernel_vec.data[0],
stride_vec.data[0],
padding_vec.data[0],
dilation_vec.data[0],
ceil_mode_val);

VK_CHECK_COND(out_sizes.at(0) >= 1);
VK_CHECK_COND(out_sizes.at(1) >= 1);

return out_sizes;
}

} // namespace vkcompute
24 changes: 15 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,20 @@ struct KernelParams final {
api::utils::ivec2 dilation;
};

int64_t calc_out_size(
const int64_t in_size,
const int64_t kernel_size,
const int64_t stride,
const int64_t padding,
const int64_t dilation,
const bool ceil_mode);

api::utils::ivec2 reverse(ComputeGraph& graph, ValueRef vref);
KernelParams create_kernel_params(
ComputeGraph& graph,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation);

std::vector<int64_t> calc_out_sizes_hw(
ComputeGraph& graph,
const std::vector<int64_t>& in_sizes,
const ValueRef kernel_size,
const ValueRef stride,
const ValueRef padding,
const ValueRef dilation,
const ValueRef ceil_mode);

} // namespace vkcompute