Skip to content

[ET-VK] Add support for aten::upsample_bilinear2d ATen op #10363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def register_view_op(features: OpFeatures):
exir_ops.edge.aten.ones.default,
exir_ops.edge.aten.ones_like.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.zeros.default,
exir_ops.edge.aten.zeros_like.default,
exir_ops.edge.et_vk.grid_priors.default,
Expand Down
71 changes: 71 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/upsample_2d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec3", "in_limits")}
${layout_declare_ubo(B, "vec2", "recip_scales")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int align_corners = 0;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits))) {
return;
}

ivec2 max_in_xy = in_limits.xy - 1;
vec2 scaled_xy;

if (align_corners == 1) {
scaled_xy = pos.xy * recip_scales;
} else {
scaled_xy = (pos.xy + 0.5) * recip_scales - 0.5;
}

$if MODE == "nearest":
const ivec2 ipos = clamp(ivec2(round(scaled_xy)), ivec2(0), max_in_xy);
VEC4_T out_tex = texelFetch(t_in, ivec3(ipos, pos.z), 0);
$elif MODE == "bilinear":
vec2 upper_xy = ceil(scaled_xy);
vec2 lower_xy = floor(scaled_xy);

// Clamp coordinates to valid input range
upper_xy = clamp(upper_xy, ivec2(0), max_in_xy);
lower_xy = clamp(lower_xy, ivec2(0), max_in_xy);

// Calculate interpolation weights
vec2 interp_weights = (scaled_xy - lower_xy);

// Sample the four nearest texels
VEC4_T sample00 = texelFetch(t_in, ivec3(lower_xy.x, lower_xy.y, pos.z), 0);
VEC4_T sample10 = texelFetch(t_in, ivec3(upper_xy.x, lower_xy.y, pos.z), 0);
VEC4_T sample01 = texelFetch(t_in, ivec3(lower_xy.x, upper_xy.y, pos.z), 0);
VEC4_T sample11 = texelFetch(t_in, ivec3(upper_xy.x, upper_xy.y, pos.z), 0);

// Perform bilinear interpolation
VEC4_T out_tex = mix(
mix(sample00, sample10, interp_weights.x),
mix(sample01, sample11, interp_weights.x),
interp_weights.y
);

imageStore(t_out, pos, out_tex);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

upsample_nearest2d:
upsample_2d:
parameter_names_with_default_values:
NDIM: 3
DTYPE: float
PACKING: C_packed
STORAGE: texture3d
MODE: nearest
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: upsample_nearest2d
- NAME: upsample_bilinear2d
MODE: bilinear
39 changes: 0 additions & 39 deletions backends/vulkan/runtime/graph/ops/glsl/upsample_nearest2d.glsl

This file was deleted.

117 changes: 80 additions & 37 deletions backends/vulkan/runtime/graph/ops/impl/Upsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

namespace vkcompute {

enum class UpsampleMode : int { NEAREST, BILINEAR };

void resize_upsample_nearest2d_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
Expand All @@ -39,19 +41,12 @@ void resize_upsample_nearest2d_node(
out->virtual_resize(out_sizes);
}

// ExecuTorch-Vulkan framework to add node
// Args:
// in: will be converted from NCHW input tensor to 3D ARGB representation in
// openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
// output size of H and W dimensions. >= input sizes;

// will be computed if only given the scale_factors.
// scale_factors: optional 2D array of scale factors for H and W dimensions.
// Will be computed if only given the output_sizes.
void add_upsample_nearest2d_node(
ComputeGraph& graph,
const UpsampleMode mode,
const ValueRef in,
const ValueRef output_sizes,
const ValueRef align_corners,
const ValueRef scale_factors,
const ValueRef out) {
if (graph.val_is_none(output_sizes) && graph.val_is_none(scale_factors)) {
Expand All @@ -63,36 +58,61 @@ void add_upsample_nearest2d_node(
"Invalid input, must provide ONLY one of output_sizes or scale_factors");
}

vTensorPtr t_in = graph.get_tensor(in);
utils::uvec3 input_sizes = t_in->logical_limits();
int align_corners_val = 0;
if (is_valid(align_corners) && graph.get_bool(align_corners)) {
align_corners_val = 1;
}

utils::uvec3 in_limits = graph.logical_limits_of(in);
utils::uvec3 out_limits = graph.logical_limits_of(out);

uint32_t out_width = out_limits[0u];
uint32_t out_height = out_limits[1u];

utils::ivec2 input_size = {
utils::safe_downcast<int32_t>(input_sizes[0]),
utils::safe_downcast<int32_t>(input_sizes[1])};
utils::vec2 rev_scales = {
utils::safe_downcast<float>(1.0), utils::safe_downcast<float>(1.0)};
float scale_factor_x = float(in_limits[0u]) / float(out_width);
float scale_factor_y = float(in_limits[1u]) / float(out_height);

float recip_scale_factor_x = 1.0f / scale_factor_x;
float recip_scale_factor_y = 1.0f / scale_factor_y;

// Reverse scale factors that pre-computed before GLSL.
if (!graph.val_is_none(output_sizes)) {
auto output_size_ref = graph.get_int_list(output_sizes);
rev_scales = {
utils::safe_downcast<float>(
(float)input_size[0] / output_size_ref->at(1)),
utils::safe_downcast<float>(
(float)input_size[1] / output_size_ref->at(0))};
IntListPtr output_size_ref = graph.get_int_list(output_sizes);
out_width = output_size_ref->at(1);
out_height = output_size_ref->at(0);

VK_CHECK_COND(out_width == out_limits[0u]);
VK_CHECK_COND(out_height == out_limits[1u]);

} else {
DoubleListPtr scales = graph.get_double_list(scale_factors);
scale_factor_x = scales->at(1);
scale_factor_y = scales->at(0);

VK_CHECK_COND(in_limits[0u] * scale_factor_x == out_width);
VK_CHECK_COND(in_limits[1u] * scale_factor_y == out_height);
}

if (align_corners_val == 1) {
recip_scale_factor_x = float(in_limits[0u] - 1) / float(out_width - 1);
recip_scale_factor_y = float(in_limits[1u] - 1) / float(out_height - 1);
} else {
auto scales = graph.get_double_list(scale_factors);
rev_scales = {
utils::safe_downcast<float>(1.0 / scales->at(1)),
utils::safe_downcast<float>(1.0 / scales->at(0))};
recip_scale_factor_x = float(in_limits[0u]) / float(out_width);
recip_scale_factor_y = float(in_limits[1u]) / float(out_height);
}

vTensorPtr t_out = graph.get_tensor(out);
utils::vec2 recip_scales = {recip_scale_factor_x, recip_scale_factor_y};

std::string kernel_name("upsample_nearest2d");
std::string kernel_name;
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, *t_out);
switch (mode) {
case UpsampleMode::NEAREST:
kernel_name = "upsample_nearest2d";
break;
case UpsampleMode::BILINEAR:
kernel_name = "upsample_bilinear2d";
break;
}
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
Expand All @@ -103,21 +123,44 @@ void add_upsample_nearest2d_node(
{{out, vkapi::MemoryAccessType::WRITE},
{in, vkapi::MemoryAccessType::READ}},
// Shader params buffers
{t_out->logical_limits_ubo(),
graph.create_params_buffer(input_size),
graph.create_params_buffer(rev_scales)},
{graph.logical_limits_ubo(out),
graph.logical_limits_ubo(in),
graph.create_params_buffer(recip_scales)},
// Specialization Constants
{},
{align_corners_val},
resize_upsample_nearest2d_node,
{output_sizes, scale_factors}));
}

void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_upsample_nearest2d_node(graph, args[0], args[1], args[2], args[3]);
void upsample_nearest2d(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
return add_upsample_nearest2d_node(
graph,
UpsampleMode::NEAREST,
args[0],
args[1],
kDummyValueRef,
args[2],
args[3]);
}

void upsample_bilinear2d(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
return add_upsample_nearest2d_node(
graph,
UpsampleMode::BILINEAR,
args[0],
args[1],
args[2],
args[3],
args[4]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample);
VK_REGISTER_OP(aten.upsample_nearest2d.vec, upsample_nearest2d);
VK_REGISTER_OP(aten.upsample_bilinear2d.vec, upsample_bilinear2d);
}

} // namespace vkcompute
41 changes: 27 additions & 14 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,21 +430,34 @@ def get_native_layer_norm_inputs():
return test_suite


@register_test_suite("aten.upsample_nearest2d.vec")
def get_upsample_inputs():
test_suite = VkTestSuite(
[
# (input tensor shape, output 2D image size (H, W), output scaling factors)
((2, 2, 2, 2), None, [1, 1]),
((1, 1, 2, 2), None, [2, 2]),
((1, 1, 2, 2), None, [2, 4]),
((1, 1, 2, 2), None, [4, 2]),
((1, 1, 2, 2), [2, 2], None),
((1, 1, 2, 2), [2, 4], None),
((1, 1, 2, 2), [3, 2], None),
]
)
return test_suite
inputs_list = [
# (input tensor shape, output 2D image size (H, W), output scaling factors)
((2, 2, 2, 2), None, [1, 1]),
((1, 1, 2, 2), None, [2, 2]),
((1, 1, 2, 2), None, [2, 4]),
((1, 1, 2, 2), None, [4, 2]),
((1, 1, 2, 2), [2, 2], None),
((1, 1, 2, 2), [2, 4], None),
((1, 1, 2, 2), [3, 2], None),
]
return inputs_list


@register_test_suite("aten.upsample_nearest2d.vec")
def get_upsample_nearest2d_inputs():
inputs_list = get_upsample_inputs()
return VkTestSuite(inputs_list)


@register_test_suite("aten.upsample_bilinear2d.vec")
def get_upsample_bilinear2d_inputs():
base_inputs_list = get_upsample_inputs()
inputs_list = []
for input_case in base_inputs_list:
inputs_list.append((input_case[0], input_case[1], False, input_case[2]))
inputs_list.append((input_case[0], input_case[1], True, input_case[2]))
return VkTestSuite(inputs_list)


@register_test_suite(["aten.full.default", "aten.full_like.default"])
Expand Down
Loading