Skip to content

Commit 810ea6e

Browse files
committed
[ET-VK] Migrate workgroup API for trivial cases
This will allow us to override local workgroup sizes with #4046. ## Before ``` vTensorPtr t_out = graph.get_tensor(out); api::utils::uvec3 global_size = t_out->image_extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); graph.execute_nodes().emplace_back(new ExecuteNode( ..., global_size, local_size, ..., ); ``` ## After ``` graph.execute_nodes().emplace_back(new ExecuteNode( ..., graph.create_global_wg_size(out), graph.create_local_wg_size(out), ..., ); ``` Note we do not migrate cases where the global size is nontrivial (MatMul, Linear, Conv1D, Repeat) or the image isn't a ValueRef (MaxPool2D, NativeLayerNorm). We should first align on an API design for those cases. Differential Revision: [D59011492](https://our.internmc.facebook.com/intern/diff/D59011492/) [ghstack-poisoned]
1 parent 34fd767 commit 810ea6e

File tree

18 files changed

+38
-98
lines changed

18 files changed

+38
-98
lines changed

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,15 @@ void add_arange_node(
8484

8585
vTensorPtr t_out = graph.get_tensor(out);
8686

87-
api::utils::uvec3 global_size = t_out->image_extents();
88-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
89-
9087
std::string kernel_name("arange");
9188
kernel_name.reserve(kShaderNameReserve);
92-
9389
add_dtype_suffix(kernel_name, *t_out);
9490

9591
graph.execute_nodes().emplace_back(new ExecuteNode(
9692
graph,
9793
VK_KERNEL_FROM_STR(kernel_name),
98-
global_size,
99-
local_size,
94+
graph.create_global_wg_size(out),
95+
graph.create_local_wg_size(out),
10096
// Inputs and Outputs
10197
{{out, api::MemoryAccessType::WRITE}},
10298
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,14 @@ void add_native_batch_norm_node(
7777
std::string kernel_name = "batchnorm";
7878
add_dtype_suffix(kernel_name, *t_out);
7979

80-
api::utils::uvec3 global_size = t_out->image_extents();
81-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
82-
8380
int32_t num_texel_per_batch =
8481
api::utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));
8582

8683
graph.execute_nodes().emplace_back(new ExecuteNode(
8784
graph,
8885
VK_KERNEL_FROM_STR(kernel_name),
89-
global_size,
90-
local_size,
86+
graph.create_global_wg_size(out_ref),
87+
graph.create_local_wg_size(out_ref),
9188
{{out_ref, api::MemoryAccessType::WRITE},
9289
{{in_ref, arg_weight, arg_bias, arg_mean, arg_var},
9390
api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ void add_binary_op_node(
6161

6262
check_binary_op_args(*t_in1, *t_in2, *t_out);
6363

64-
api::utils::uvec3 global_size = t_out->image_extents();
65-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
66-
6764
float alpha_val = 1.0f;
6865
// String is checked since floor_div passes in an unused string argument in
6966
// place of alpha
@@ -82,8 +79,8 @@ void add_binary_op_node(
8279
graph.execute_nodes().emplace_back(new ExecuteNode(
8380
graph,
8481
VK_KERNEL_FROM_STR(kernel_name),
85-
global_size,
86-
local_size,
82+
graph.create_global_wg_size(out),
83+
graph.create_local_wg_size(out),
8784
// Inputs and Outputs
8885
{{out, api::MemoryAccessType::WRITE},
8986
{{arg1, arg2}, api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,11 @@ void add_clone_node(
2525
std::string kernel_name = "clone";
2626
add_dtype_suffix(kernel_name, *t_out);
2727

28-
api::utils::uvec3 global_size = t_out->image_extents();
29-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
30-
3128
graph.execute_nodes().emplace_back(new ExecuteNode(
3229
graph,
3330
VK_KERNEL_FROM_STR(kernel_name),
34-
global_size,
35-
local_size,
31+
graph.create_global_wg_size(out),
32+
graph.create_local_wg_size(out),
3633
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
3734
{t_out->texture_limits_ubo()}));
3835
}

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ void add_copy_offset_node(
3232
kernel_name.reserve(kShaderNameReserve);
3333
add_dtype_suffix(kernel_name, *t_out);
3434

35-
uvec3 global_size = api::utils::make_uvec3(range);
36-
uvec3 local_size = adaptive_work_group_size(global_size);
37-
3835
const struct Block final {
3936
ivec3 range;
4037
int32_t unused0;
@@ -56,8 +53,8 @@ void add_copy_offset_node(
5653
graph.execute_nodes().emplace_back(new ExecuteNode(
5754
graph,
5855
VK_KERNEL_FROM_STR(kernel_name),
59-
global_size,
60-
local_size,
56+
graph.create_global_wg_size(out),
57+
graph.create_local_wg_size(out),
6158
// Inputs and Outputs
6259
{
6360
{out, api::MemoryAccessType::WRITE},
@@ -141,7 +138,6 @@ void add_copy_channel_offset_node(
141138
api::utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
142139
api::utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
143140
api::utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
144-
145141
uvec3 local_size = adaptive_work_group_size(global_size);
146142

147143
const struct Block final {

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,11 @@ void add_embedding_node(
4141
kernel_name.reserve(kShaderNameReserve);
4242
add_dtype_suffix(kernel_name, *t_out);
4343

44-
api::utils::uvec3 global_size = t_out->image_extents();
45-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
46-
4744
graph.execute_nodes().emplace_back(new ExecuteNode(
4845
graph,
4946
VK_KERNEL_FROM_STR(kernel_name),
50-
global_size,
51-
local_size,
47+
graph.create_global_wg_size(out),
48+
graph.create_local_wg_size(out),
5249
{{out, api::MemoryAccessType::WRITE},
5350
{{in, weight}, api::MemoryAccessType::READ}},
5451
{t_out->sizes_ubo()}));

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ void add_full_node(
3939
float fill_value_val = graph.extract_scalar<float>(fill_value);
4040
vTensorPtr t_out = graph.get_tensor(out);
4141

42-
api::utils::uvec3 global_size = t_out->image_extents();
43-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
44-
4542
std::string kernel_name("full");
4643
kernel_name.reserve(kShaderNameReserve);
4744

@@ -50,8 +47,8 @@ void add_full_node(
5047
graph.execute_nodes().emplace_back(new ExecuteNode(
5148
graph,
5249
VK_KERNEL_FROM_STR(kernel_name),
53-
global_size,
54-
local_size,
50+
graph.create_global_wg_size(out),
51+
graph.create_local_wg_size(out),
5552
// Inputs and Outputs
5653
{{out, api::MemoryAccessType::WRITE}},
5754
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,11 @@ void add_index_select_channel_node(
4141
kernel_name.reserve(kShaderNameReserve);
4242
add_dtype_suffix(kernel_name, *t_out);
4343

44-
api::utils::uvec3 global_size = t_out->image_extents();
45-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
46-
4744
graph.execute_nodes().emplace_back(new ExecuteNode(
4845
graph,
4946
VK_KERNEL_FROM_STR(kernel_name),
50-
global_size,
51-
local_size,
47+
graph.create_global_wg_size(out),
48+
graph.create_local_wg_size(out),
5249
{{out, api::MemoryAccessType::WRITE},
5350
{{in, idx}, api::MemoryAccessType::READ}},
5451
{t_out->sizes_ubo(), t_in->sizes_ubo()}));
@@ -93,14 +90,11 @@ void add_index_select_node(
9390
kernel_name.reserve(kShaderNameReserve);
9491
add_dtype_suffix(kernel_name, *t_out);
9592

96-
api::utils::uvec3 global_size = t_out->image_extents();
97-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
98-
9993
graph.execute_nodes().emplace_back(new ExecuteNode(
10094
graph,
10195
VK_KERNEL_FROM_STR(kernel_name),
102-
global_size,
103-
local_size,
96+
graph.create_global_wg_size(out),
97+
graph.create_local_wg_size(out),
10498
{{out, api::MemoryAccessType::WRITE},
10599
{{in, idx}, api::MemoryAccessType::READ}},
106100
{t_out->sizes_ubo(), graph.create_params_buffer(params)}));

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,6 @@ void add_addmm_naive_node(
101101
ValueRef self = prepack_if_tensor_ref(graph, self_data, api::kWidthPacked);
102102
ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);
103103

104-
api::utils::uvec3 global_size = graph.image_extents_of(out);
105-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
106-
107104
std::string kernel_name =
108105
graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
109106
kernel_name.reserve(kShaderNameReserve);
@@ -114,8 +111,8 @@ void add_addmm_naive_node(
114111
graph.execute_nodes().emplace_back(new ExecuteNode(
115112
graph,
116113
VK_KERNEL_FROM_STR(kernel_name),
117-
global_size,
118-
local_size,
114+
graph.create_global_wg_size(out),
115+
graph.create_local_wg_size(out),
119116
// Inputs and Outputs
120117
{{out, api::MemoryAccessType::WRITE},
121118
{{mat1, mat2, self}, api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ void add_matmul_naive_node(
7272
const ValueRef mat2_is_transposed) {
7373
ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked);
7474

75-
api::utils::uvec3 global_size = graph.image_extents_of(out);
76-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
77-
7875
std::string kernel_name = graph.get_bool(mat2_is_transposed)
7976
? "matmul_transposed_naive"
8077
: "matmul_naive";
@@ -86,8 +83,8 @@ void add_matmul_naive_node(
8683
graph.execute_nodes().emplace_back(new ExecuteNode(
8784
graph,
8885
VK_KERNEL_FROM_STR(kernel_name),
89-
global_size,
90-
local_size,
86+
graph.create_global_wg_size(out),
87+
graph.create_local_wg_size(out),
9188
// Inputs and Outputs
9289
{{out, api::MemoryAccessType::WRITE},
9390
{{mat1, mat2}, api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/Pad.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ void add_constant_pad_nd_node(
6565
vTensorPtr t_in = graph.get_tensor(in);
6666
vTensorPtr t_out = graph.get_tensor(out);
6767

68-
api::utils::uvec3 global_size = t_out->image_extents();
69-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
70-
7168
std::string kernel_name = "";
7269
PadParam pad_param = creat_pad_param(*pad_vec);
7370

@@ -84,8 +81,8 @@ void add_constant_pad_nd_node(
8481
graph.execute_nodes().emplace_back(new ExecuteNode(
8582
graph,
8683
VK_KERNEL_FROM_STR(kernel_name),
87-
global_size,
88-
local_size,
84+
graph.create_global_wg_size(out),
85+
graph.create_local_wg_size(out),
8986
// Inputs and Outputs
9087
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
9188
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Permute.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,11 @@ void add_permute_node(
8585
{out_c_aligned, in_c_aligned},
8686
};
8787

88-
api::utils::uvec3 global_size = t_out->image_extents();
89-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
90-
9188
graph.execute_nodes().emplace_back(new ExecuteNode(
9289
graph,
9390
VK_KERNEL_FROM_STR(kernel_name),
94-
global_size,
95-
local_size,
91+
graph.create_global_wg_size(out),
92+
graph.create_local_wg_size(out),
9693
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
9794
{t_out->texture_limits_ubo(),
9895
t_out->sizes_ubo(),

backends/vulkan/runtime/graph/ops/impl/Select.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,12 @@ void add_select_int_node(
102102
kernel_name.reserve(kShaderNameReserve);
103103
add_dtype_suffix(kernel_name, *t_out);
104104

105-
api::utils::uvec3 global_size = t_out->image_extents();
106-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
107-
108105
// TODO: add resizing to support dynamic shapes.
109106
graph.execute_nodes().emplace_back(new ExecuteNode(
110107
graph,
111108
VK_KERNEL_FROM_STR(kernel_name),
112-
global_size,
113-
local_size,
109+
graph.create_global_wg_size(out),
110+
graph.create_local_wg_size(out),
114111
// Inputs and Outputs
115112
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
116113
// Parameter buffers

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ void add_slice_tensor_out_node(
8080
kernel_name.reserve(kShaderNameReserve);
8181
add_dtype_suffix(kernel_name, *t_out);
8282

83-
api::utils::uvec3 global_size = t_out->image_extents();
84-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
85-
8683
const struct Block final {
8784
int offset;
8885
int step;
@@ -94,8 +91,8 @@ void add_slice_tensor_out_node(
9491
graph.execute_nodes().emplace_back(new ExecuteNode(
9592
graph,
9693
VK_KERNEL_FROM_STR(kernel_name),
97-
global_size,
98-
local_size,
94+
graph.create_global_wg_size(out),
95+
graph.create_local_wg_size(out),
9996
{{out, api::MemoryAccessType::WRITE},
10097
{in, api::MemoryAccessType::READ}},
10198
{t_out->sizes_ubo(),

backends/vulkan/runtime/graph/ops/impl/Softmax.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ void add_softmax_node(
4343
softmax_dim = normalize(softmax_dim, in_dim);
4444

4545
vTensorPtr t_out = graph.get_tensor(out);
46-
uvec3 global_size = t_out->image_extents();
4746

4847
api::ShaderInfo shader_descriptor;
4948
std::string kernel_name = in_dim - softmax_dim == 3
@@ -55,14 +54,12 @@ void add_softmax_node(
5554
kernel_name = "log_" + kernel_name;
5655
}
5756

58-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
59-
6057
graph.execute_nodes().emplace_back(new ExecuteNode(
6158
graph,
6259
// shader_descriptor,
6360
VK_KERNEL_FROM_STR(kernel_name),
64-
global_size,
65-
local_size,
61+
graph.create_global_wg_size(out),
62+
graph.create_local_wg_size(out),
6663
// Inputs and Outputs
6764
{{out, api::MemoryAccessType::WRITE},
6865
{in_arg, api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/Sum.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,18 @@ void add_sum_dim_node(
6868
in_dim > 2 ? static_cast<int32_t>(t_input->sizes()[in_dim - 3]) : 1;
6969
uint32_t dim_size = t_input->sizes()[dim];
7070

71-
api::utils::uvec3 global_size = t_out->image_extents();
72-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
73-
7471
std::string kernel_name("sum_dim");
7572
kernel_name.reserve(kShaderNameReserve);
7673
if (keepdim) {
7774
kernel_name += "_keepdim";
7875
}
79-
8076
add_dtype_suffix(kernel_name, *t_out);
8177

8278
graph.execute_nodes().emplace_back(new ExecuteNode(
8379
graph,
8480
VK_KERNEL_FROM_STR(kernel_name),
85-
global_size,
86-
local_size,
81+
graph.create_global_wg_size(out),
82+
graph.create_local_wg_size(out),
8783
// Inputs and Outputs
8884
{{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}},
8985
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Upsample.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,16 @@ void add_upsample_nearest2d_node(
6969
}
7070

7171
vTensorPtr t_out = graph.get_tensor(out);
72-
api::utils::uvec3 global_size = t_out->image_extents();
73-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
7472

7573
std::string kernel_name("upsample_nearest2d");
7674
kernel_name.reserve(kShaderNameReserve);
77-
7875
add_dtype_suffix(kernel_name, *t_out);
7976

8077
graph.execute_nodes().emplace_back(new ExecuteNode(
8178
graph,
8279
VK_KERNEL_FROM_STR(kernel_name),
83-
global_size,
84-
local_size,
80+
graph.create_global_wg_size(out),
81+
graph.create_local_wg_size(out),
8582
// Inputs and Outputs
8683
{{out, api::MemoryAccessType::WRITE},
8784
{arg_in, api::MemoryAccessType::READ}},

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,11 @@ void add_view_node(
6565
kernel_name.reserve(kShaderNameReserve);
6666
add_dtype_suffix(kernel_name, *t_out);
6767

68-
api::utils::uvec3 global_size = t_out->image_extents();
69-
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
70-
7168
graph.execute_nodes().emplace_back(new ExecuteNode(
7269
graph,
7370
VK_KERNEL_FROM_STR(kernel_name),
74-
global_size,
75-
local_size,
71+
graph.create_global_wg_size(out),
72+
graph.create_local_wg_size(out),
7673
// Inputs and Outputs
7774
{{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
7875
// Parameter Buffers

0 commit comments

Comments
 (0)