|
16 | 16 |
|
17 | 17 | namespace vkcompute {
|
18 | 18 |
|
| 19 | +void resize_upsample_nearest2d_node( |
| 20 | + ComputeGraph* graph, |
| 21 | + const std::vector<ArgGroup>& args, |
| 22 | + const std::vector<ValueRef>& extra_args) { |
| 23 | + vTensorPtr out = graph->get_tensor(args[0].refs[0]); |
| 24 | + vTensorPtr self = graph->get_tensor(args[1].refs[0]); |
| 25 | + std::vector<int64_t> out_sizes = self->sizes(); // NCHW |
| 26 | + |
| 27 | + const ValueRef output_sizes = extra_args[0]; // HW |
| 28 | + const ValueRef scale_factors = extra_args[1]; // HW |
| 29 | + if (!graph->val_is_none(output_sizes)) { |
| 30 | + IntListPtr output_size_ref = graph->get_int_list(output_sizes); |
| 31 | + out_sizes.at(2) = output_size_ref->at(0); |
| 32 | + out_sizes.at(3) = output_size_ref->at(1); |
| 33 | + } else { |
| 34 | + DoubleListPtr scales = graph->get_double_list(scale_factors); |
| 35 | + out_sizes.at(2) *= scales->at(0); |
| 36 | + out_sizes.at(3) *= scales->at(1); |
| 37 | + } |
| 38 | + |
| 39 | + out->virtual_resize(out_sizes); |
| 40 | +} |
| 41 | + |
19 | 42 | // ExecuTorch-Vulkan framework to add node
|
20 | 43 | // Args:
|
21 | 44 | // in: will be converted from NCHW input tensor to 3D ARGB representation in
|
@@ -87,7 +110,9 @@ void add_upsample_nearest2d_node(
|
87 | 110 | graph.create_params_buffer(input_size),
|
88 | 111 | graph.create_params_buffer(rev_scales)},
|
89 | 112 | // Specialization Constants
|
90 |
| - {})); |
| 113 | + {}, |
| 114 | + resize_upsample_nearest2d_node, |
| 115 | + {output_sizes, scale_factors})); |
91 | 116 | }
|
92 | 117 |
|
93 | 118 | void upsample(ComputeGraph& graph, const std::vector<ValueRef>& args) {
|
|
0 commit comments