Commit af49578

Update on "[ET-VK] Refactor Pool.cpp"

This change adds more lines than it subtracts, but it'll be worth it once we reuse the methods for `aten.convolution`.

Differential Revision: [D55706057](https://our.internmc.facebook.com/intern/diff/D55706057/)

[ghstack-poisoned]

2 parents 4336380 + a913947 commit af49578

6 files changed, +58 -64 lines changed


backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 11 additions & 9 deletions
@@ -123,13 +123,6 @@ ValueRef ComputeGraph::add_tensor(
   return idx;
 }
 
-ValueRef ComputeGraph::add_tensor(
-    TensorRef& tref,
-    const api::StorageType storage_type,
-    const api::GPUMemoryLayout memory_layout) {
-  return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
-}
-
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const api::ScalarType dtype,
@@ -139,9 +132,18 @@ ValueRef ComputeGraph::add_tensor(
       sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
 }
 
-ValueRef ComputeGraph::add_tensor(
-    TensorRef& tref,
+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
+    const api::StorageType storage_type,
+    const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
+  return add_tensor(tref.sizes, tref.dtype, storage_type, memory_layout);
+}
+
+ValueRef ComputeGraph::add_tensor_like(
+    const ValueRef vref,
     const api::GPUMemoryLayout memory_layout) {
+  TensorRef& tref = get_val(vref).toTensorRef();
   return add_tensor(tref.sizes, tref.dtype, memory_layout);
 }
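
For orientation, here is a minimal usage sketch of the two new `add_tensor_like` overloads. The caller-side names (`graph`, `in_ref`, `storage_type`, `layout`) are hypothetical; only the overload signatures come from this diff.

// Hypothetical caller code; assumes `in_ref` is a ValueRef that holds a TensorRef.
// Three-argument overload: mirror the sizes/dtype of the referenced tensor while
// pinning both the storage type and the memory layout.
ValueRef out = graph.add_tensor_like(in_ref, storage_type, layout);

// Two-argument overload: the graph's suggested storage type is used and only the
// memory layout is pinned (this is what prepack() in Staging.cpp now does).
ValueRef staged = graph.add_tensor_like(in_ref, layout);

Compared with the removed `add_tensor(TensorRef&, ...)` overloads, callers now pass the `ValueRef` directly and the `get_val(vref).toTensorRef()` lookup happens inside `ComputeGraph`.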

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 16 additions & 16 deletions
@@ -174,14 +174,6 @@ class ComputeGraph final {
       const api::GPUMemoryLayout memory_layout,
       const int64_t shared_object_idx = -1);
 
-  /*
-   * Add a `vTensor` value to the graph with the properties of `tref`.
-   */
-  ValueRef add_tensor(
-      TensorRef& tref,
-      const api::StorageType storage_type,
-      const api::GPUMemoryLayout memory_layout);
-
   /*
    * Add a `vTensor` value to the graph with the specified properties. The
    * suggested storage type will be used to construct the `vTensor`.
@@ -192,14 +184,6 @@ class ComputeGraph final {
       const api::GPUMemoryLayout memory_layout,
       const int64_t shared_object_idx = -1);
 
-  /*
-   * Add a `vTensor` value to the graph with the properties of `tref`. The
-   * suggested storage type will be used to construct the `vTensor`.
-   */
-  ValueRef add_tensor(
-      TensorRef& tref,
-      const api::GPUMemoryLayout memory_layout);
-
   /*
    * Add a `vTensor` value to the graph with the specified properties. The
    * suggested storage type and memory layout will be used to construct the
@@ -210,6 +194,22 @@
       const api::ScalarType dtype,
       const int64_t shared_object_idx = -1);
 
+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::StorageType storage_type,
+      const api::GPUMemoryLayout memory_layout);
+
+  /*
+   * Add a `vTensor` value to the graph with the properties of `vref`. The
+   * suggested storage type will be used to construct the `vTensor`.
+   */
+  ValueRef add_tensor_like(
+      const ValueRef vref,
+      const api::GPUMemoryLayout memory_layout);
+
   /*
    * Add a `TensorRef` value to the graph with the specific properties. A
    * `TensorRef` is a reference to a `vTensor` whose data is stored in an

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 3 additions & 3 deletions
@@ -35,16 +35,16 @@ void resize_max_pool2d_node(
   new_out_sizes.at(ndim - 3) = self.sizes().at(ndim - 3);
 
   // Height, Width
-  const auto hw_sizes = calc_hw_out_sizes(
+  const auto new_out_sizes_hw = calc_out_sizes_hw(
       *graph,
       self.sizes(),
       extra_args[0],
       extra_args[1],
       extra_args[2],
       extra_args[3],
       extra_args[4]);
-  new_out_sizes.at(ndim - 2) = hw_sizes.at(0);
-  new_out_sizes.at(ndim - 1) = hw_sizes.at(1);
+  new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
+  new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
 
   out.virtual_resize(new_out_sizes);
   indices.virtual_resize(new_out_sizes);

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ ValueRef prepack(
     ComputeGraph& graph,
     const ValueRef vref,
     const api::GPUMemoryLayout layout) {
-  ValueRef v = graph.add_tensor(graph.get_val(vref).toTensorRef(), layout);
+  ValueRef v = graph.add_tensor_like(vref, layout);
   vTensor& t = graph.get_val(v).toTensor();
 
   api::ShaderInfo shader = get_nchw_to_image_shader(t);

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 25 additions & 26 deletions
@@ -10,9 +10,9 @@
 
 namespace vkcompute {
 
-api::utils::ivec2
-make_ivec2_int_list(ComputeGraph& graph, ValueRef vref, const bool reverse) {
-  return api::utils::make_ivec2(graph.get_val(vref).toIntList(), reverse);
+api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) {
+  return api::utils::make_ivec2(
+      graph.get_val(vref).toIntList(), /*reverse = */ true);
 }
 
 KernelParams create_kernel_params(
@@ -22,30 +22,31 @@ KernelParams create_kernel_params(
     const ValueRef padding,
     const ValueRef dilation) {
   return {
-      make_ivec2_int_list(graph, kernel_size, /*reverse=*/true),
-      make_ivec2_int_list(graph, stride, /*reverse=*/true),
-      make_ivec2_int_list(graph, padding, /*reverse=*/true),
-      make_ivec2_int_list(graph, dilation, /*reverse=*/true),
+      make_ivec2_from_list(graph, kernel_size),
+      make_ivec2_from_list(graph, stride),
+      make_ivec2_from_list(graph, padding),
+      make_ivec2_from_list(graph, dilation),
   };
 }
 
 int64_t calc_out_size(
     const int64_t in_size,
-    const int64_t kernel,
+    const int64_t kernel_size,
     const int64_t stride,
     const int64_t padding,
     const int64_t dilation,
     const bool ceil_mode) {
   int64_t c = ceil_mode ? stride - 1 : 0;
   int64_t out_size =
-      (in_size + 2 * padding - dilation * (kernel - 1) - 1 + c) / stride + 1;
+      (in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride +
+      1;
   if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
     --out_size;
   }
   return out_size;
 }
 
-std::vector<int64_t> calc_hw_out_sizes(
+std::vector<int64_t> calc_out_sizes_hw(
     ComputeGraph& graph,
     const std::vector<int64_t>& in_sizes,
     const ValueRef kernel_size,
@@ -56,30 +57,28 @@ std::vector<int64_t> calc_hw_out_sizes(
   const int64_t ndim = in_sizes.size();
   std::vector<int64_t> out_sizes(2);
 
-  const auto kernel_vec =
-      make_ivec2_int_list(graph, kernel_size, /*reverse=*/false);
-  const auto stride_vec = make_ivec2_int_list(graph, stride, /*reverse=*/false);
-  const auto padding_vec =
-      make_ivec2_int_list(graph, padding, /*reverse=*/false);
-  const auto dilation_vec =
-      make_ivec2_int_list(graph, dilation, /*reverse=*/false);
+  const auto kernel_vec = make_ivec2_from_list(graph, kernel_size);
+  const auto stride_vec = make_ivec2_from_list(graph, stride);
+  const auto padding_vec = make_ivec2_from_list(graph, padding);
+  const auto dilation_vec = make_ivec2_from_list(graph, dilation);
+  const bool ceil_mode_val = graph.get_val(ceil_mode).toBool();
 
   // Height
   out_sizes.at(0) = calc_out_size(
       in_sizes.at(ndim - 2),
-      kernel_vec.data[0],
-      stride_vec.data[0],
-      padding_vec.data[0],
-      dilation_vec.data[0],
-      ceil_mode);
-  // Width
-  out_sizes.at(1) = calc_out_size(
-      in_sizes.at(ndim - 1),
       kernel_vec.data[1],
       stride_vec.data[1],
       padding_vec.data[1],
       dilation_vec.data[1],
-      ceil_mode);
+      ceil_mode_val);
+  // Width
+  out_sizes.at(1) = calc_out_size(
+      in_sizes.at(ndim - 1),
+      kernel_vec.data[0],
+      stride_vec.data[0],
+      padding_vec.data[0],
+      dilation_vec.data[0],
+      ceil_mode_val);
 
   VK_CHECK_COND(out_sizes.at(0) >= 1);
   VK_CHECK_COND(out_sizes.at(1) >= 1);
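
For reference, `calc_out_size` implements the standard pooling/convolution output-size formula: out = (in_size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1 with floor division, where ceil mode adds stride - 1 before the division and then drops a trailing window that would start beyond the padded input. Below is a self-contained sketch of that arithmetic with an illustrative worked example; the standalone `pool_out_size` name and the numbers are not from this diff.

#include <cstdint>
#include <iostream>

// Standalone sketch of the arithmetic in calc_out_size; the ComputeGraph
// plumbing is omitted so this compiles on its own.
std::int64_t pool_out_size(
    std::int64_t in_size,
    std::int64_t kernel_size,
    std::int64_t stride,
    std::int64_t padding,
    std::int64_t dilation,
    bool ceil_mode) {
  // ceil_mode adds (stride - 1) before the floor division, which rounds up.
  const std::int64_t c = ceil_mode ? stride - 1 : 0;
  std::int64_t out_size =
      (in_size + 2 * padding - dilation * (kernel_size - 1) - 1 + c) / stride +
      1;
  // The last window must start inside the input or padding; otherwise drop it.
  if (ceil_mode && (out_size - 1) * stride >= in_size + padding) {
    --out_size;
  }
  return out_size;
}

int main() {
  // Illustrative numbers: in_size = 7, kernel_size = 3, stride = 2, padding = 1,
  // dilation = 1 -> (7 + 2 - 2 - 1) / 2 + 1 = 4.
  std::cout << pool_out_size(7, 3, 2, 1, 1, /*ceil_mode=*/false) << "\n";
  return 0;
}

Note also the index swap in `calc_out_sizes_hw`: because `make_ivec2_from_list` now always reverses the `{height, width}` int list, `data[0]` holds the width component and `data[1]` the height, which is why the height computation reads `.data[1]` and the width computation reads `.data[0]`.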

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 2 additions & 9 deletions
@@ -30,20 +30,13 @@ KernelParams create_kernel_params(
     const ValueRef padding,
     const ValueRef dilation);
 
-int64_t calc_out_size(
-    const int64_t in_size,
-    const int64_t kernel_size,
-    const int64_t stride,
-    const int64_t padding,
-    const int64_t dilation,
-    const bool ceil_mode);
-
-std::vector<int64_t> calc_hw_out_sizes(
+std::vector<int64_t> calc_out_sizes_hw(
     ComputeGraph& graph,
     const std::vector<int64_t>& in_sizes,
     const ValueRef kernel_size,
     const ValueRef stride,
     const ValueRef padding,
     const ValueRef dilation,
     const ValueRef ceil_mode);
+
 } // namespace vkcompute
