Skip to content

Commit decd2bf

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
Add resize function for aten.upsample_nearest2d.vec operator (#4069)
Summary: Pull Request resolved: #4069 aten.upsample_nearest2d.vec is an operator in OCR model (Neck module). We added support to not decompose this op when lowering to Vulkan delegation. Add resize function so that we only need to build the graph once to test input with different sizes. Reviewed By: copyrightly, jorgep31415 Differential Revision: D59030889
1 parent eebffae commit decd2bf

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

backends/vulkan/runtime/graph/ops/impl/Upsample.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,29 @@
1616

1717
namespace vkcompute {
1818

19+
void resize_upsample_nearest2d_node(
20+
ComputeGraph* graph,
21+
const std::vector<ArgGroup>& args,
22+
const std::vector<ValueRef>& extra_args) {
23+
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
24+
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
25+
std::vector<int64_t> out_sizes = self->sizes(); // NCHW
26+
27+
const ValueRef output_sizes = extra_args[0]; // HW
28+
const ValueRef scale_factors = extra_args[1]; // HW
29+
if (!graph->val_is_none(output_sizes)) {
30+
IntListPtr output_size_ref = graph->get_int_list(output_sizes);
31+
out_sizes.at(2) = output_size_ref->at(0);
32+
out_sizes.at(3) = output_size_ref->at(1);
33+
} else {
34+
DoubleListPtr scales = graph->get_double_list(scale_factors);
35+
out_sizes.at(2) *= scales->at(0);
36+
out_sizes.at(3) *= scales->at(1);
37+
}
38+
39+
out->virtual_resize(out_sizes);
40+
}
41+
1942
// ExecuTorch-Vulkan framework to add node
2043
// Args:
2144
// in: will be converted from NCHW input tensor to 3D ARGB representation in
@@ -87,7 +110,9 @@ void add_upsample_nearest2d_node(
87110
graph.create_params_buffer(input_size),
88111
graph.create_params_buffer(rev_scales)},
89112
// Specialization Constants
90-
{}));
113+
{},
114+
resize_upsample_nearest2d_node,
115+
{output_sizes, scale_factors}));
91116
}
92117

93118
void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {

0 commit comments

Comments
 (0)