Skip to content

Commit 3020133

Browse files
committed
Update base for Update on "[ET-VK][Ops] aten.convolution (Pointwise)"
We port an optimization from ATen-VK for specific weight sizes: [`conv2d_pw.glsl`](https://github.com/pytorch/pytorch/blob/09c72eaa3f69f90402c86a30abf4fc621298578c/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl).

Differential Revision: [D55814587](https://our.internmc.facebook.com/intern/diff/D55814587/)

[ghstack-poisoned]
1 parent 6c1ab36 commit 3020133

File tree

2 files changed: 9 additions and 9 deletions.

backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ void resize_conv2d_node(
3434
if (ndim == 4) {
3535
new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4);
3636
}
37-
const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
37+
const auto& weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes;
3838
new_out_sizes.at(ndim - 3) =
3939
transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);
4040

4141
// Height, Width
42-
const auto new_out_sizes_hw = calc_out_sizes_hw(
42+
const auto& new_out_sizes_hw = calc_out_sizes_hw(
4343
*graph,
4444
self.sizes(),
4545
extra_args[0],
@@ -97,7 +97,7 @@ api::ShaderInfo get_conv2d_shader(
9797
case Conv2dMethod::Depthwise:
9898
kernel_name << "conv2d_dw";
9999
if (!prepack_weights) {
100-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
100+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
101101
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
102102
kernel_name << "_output_tile_3x3";
103103
}
@@ -167,7 +167,7 @@ ValueRef prepack_weights(
167167
const ValueRef vref,
168168
const Conv2dMethod method) {
169169
const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
170-
const auto final_sizes = get_final_sizes(original_sizes, method);
170+
const auto& final_sizes = get_final_sizes(original_sizes, method);
171171

172172
ValueRef v = graph.add_tensor(
173173
final_sizes,
@@ -182,7 +182,7 @@ ValueRef prepack_weights(
182182
api::ShaderInfo shader =
183183
get_conv2d_shader(graph, t, /*prepack_weights = */ true, method, vref);
184184

185-
const auto padded_sizes = get_padded_sizes(original_sizes, method);
185+
const auto& padded_sizes = get_padded_sizes(original_sizes, method);
186186

187187
graph.prepack_nodes().emplace_back(new PrepackNode(
188188
graph,
@@ -221,13 +221,13 @@ Conv2dParams create_conv2d_params(
221221
const ValueRef weight,
222222
const KernelParams& p,
223223
const bool transposed) {
224-
const auto overlay_region = api::utils::make_ivec2({
224+
const auto& overlay_region = api::utils::make_ivec2({
225225
p.kernel_size.data[0] +
226226
(p.kernel_size.data[0] - 1) * (p.dilation.data[0] - 1),
227227
p.kernel_size.data[1] +
228228
(p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1),
229229
});
230-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
230+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
231231
const int32_t in_group_size =
232232
api::utils::safe_downcast<int32_t>(api::utils::align_up(
233233
transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4)));
@@ -255,7 +255,7 @@ Conv2dMethod get_conv2d_method(
255255
const ValueRef weight,
256256
const int64_t groups,
257257
const bool transposed) {
258-
const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
258+
const auto& weight_sizes = graph.get_val(weight).toTensorRef().sizes;
259259
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
260260
return Conv2dMethod::Depthwise;
261261
}

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void resize_max_pool2d_node(
3535
new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
3636

3737
// Height, Width
38-
const auto new_out_sizes_hw = calc_out_sizes_hw(
38+
const auto& new_out_sizes_hw = calc_out_sizes_hw(
3939
*graph,
4040
self.sizes(),
4141
extra_args[0],

0 commit comments

Comments
 (0)