Skip to content

Commit 4c74c0e

Browse files
committed
Update on "[ET-VK][Ops] aten.convolution (Pointwise)"
We port an optimization from ATen-VK for specific weight sizes: [`conv2d_pw.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl). Differential Revision: [D55814587](https://our.internmc.facebook.com/intern/diff/D55814587/). [ghstack-poisoned]
2 parents be5d23b + 3020133 commit 4c74c0e

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ void resize_conv2d_node(
3434
if (ndim == 4) {
3535
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
3636
}
37-
const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
37+
const auto& weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
3838
new_out_sizes.at(ndim - 3) =
3939
transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);
4040

4141
// Height, Width
42-
const auto new_out_sizes_hw = calc_out_sizes_hw(
42+
const auto& new_out_sizes_hw = calc_out_sizes_hw(
4343
*graph,
4444
self.sizes(),
4545
extra_args[0],
@@ -98,7 +98,7 @@ api::ShaderInfo get_conv2d_shader(
9898
case Conv2dMethod::Depthwise:
9999
kernel_name << "conv2d_dw";
100100
if (!prepack_weights) {
101-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
101+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
102102
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
103103
kernel_name << "_output_tile_3x3";
104104
}
@@ -177,7 +177,7 @@ ValueRef prepack_weights(
177177
const ValueRef vref,
178178
const Conv2dMethod method) {
179179
const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
180-
const auto final_sizes = get_final_sizes(original_sizes, method);
180+
const auto& final_sizes = get_final_sizes(original_sizes, method);
181181

182182
ValueRef v = graph.add_tensor(
183183
final_sizes,
@@ -192,7 +192,7 @@ ValueRef prepack_weights(
192192
api::ShaderInfo shader =
193193
get_conv2d_shader(graph, t, /*prepack_weights = */ true, method, vref);
194194

195-
const auto padded_sizes = get_padded_sizes(original_sizes, method);
195+
const auto& padded_sizes = get_padded_sizes(original_sizes, method);
196196

197197
graph.prepack_nodes().emplace_back(new PrepackNode(
198198
graph,
@@ -231,13 +231,13 @@ Conv2dParams create_conv2d_params(
231231
const ValueRef weight,
232232
const KernelParams& p,
233233
const bool transposed) {
234-
const auto overlay_region = api::utils::make_ivec2({
234+
const auto& overlay_region = api::utils::make_ivec2({
235235
p.kernel_size.data[0] +
236236
(p.kernel_size.data[0] - 1) * (p.dilation.data[0] - 1),
237237
p.kernel_size.data[1] +
238238
(p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1),
239239
});
240-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
240+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
241241
const int32_t in_group_size =
242242
api::utils::safe_downcast<int32_t>(api::utils::align_up(
243243
transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4)));
@@ -265,7 +265,7 @@ Conv2dMethod get_conv2d_method(
265265
const ValueRef weight,
266266
const int64_t groups,
267267
const bool transposed) {
268-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
268+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
269269
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
270270
return Conv2dMethod::Depthwise;
271271
}

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void resize_max_pool2d_node(
3535
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
3636

3737
// Height, Width
38-
const auto new_out_sizes_hw = calc_out_sizes_hw(
38+
const auto& new_out_sizes_hw = calc_out_sizes_hw(
3939
*graph,
4040
self.sizes(),
4141
extra_args[0],

0 commit comments

Comments
 (0)