
Commit a7e7664

[ET-VK] Clean up prepack API
Pull Request resolved: #6331

## Context

As the title says, this revamps the prepacking API:

* Make the naming clearer, i.e. rename `prepack_if_tensor_ref` to `prepack_standard`, to disambiguate the packing logic that will be used.
* Instead of passing the `v` argument through by default when it is already a `Tensor`, this behavior must now be toggled explicitly via the `passthrough` argument. The goal is to encourage developers to be explicit about which types they expect operator arguments to be.
* Consolidate the API surface and reduce the number of overloads.

Beyond the API changes, I have also removed a number of unnecessary calls to `prepack_if_tensor_ref` throughout the operator implementations. The most common case was calling it on an input tensor, which is not necessary.

## The "big picture" for prepacking

`TensorRef` objects and prepacking are used whenever we deal with a tensor whose data is serialized with the model. These "serialized tensors" all belong to one of two categories:

* Weights/biases: trained weights and biases that act as the state for, e.g., a Convolutional or Linear layer. These tensors are used only within the `nn.Module` that they belong to.
* Persistent tensors: tensors whose data just happens to be invariant to the inputs, so their data can be serialized with the model itself. They are treated as regular tensors and may be used in several operators throughout the model. One example is `freqs_sin` and `freqs_cos` in Llama models, which are used to calculate rotary positional encodings.

For weights and biases, how the serialized data should be packed may depend on the operator it is used in. Persistent tensors, however, must be packed with the "standard" staging-to-tensor algorithm, since they behave exactly like regular tensors.

It is well known which operators expect weight tensors. Persistent tensors are tricky, though, because they can be used as an argument to any operator. This would mean that every operator needs to account for the possibility that one of its inputs is a serialized tensor. That is undesirable because it adds an additional layer of indirection when processing operator inputs, on top of the fact that every argument is already a reference to a `Value` object in the graph, which is itself a wrapper. It also complicates things for the operator developer. Another downside is that a persistent tensor will be packed multiple times, once by each operator that uses it.

To address this, I plan to handle persistent tensors at export time by inserting a `prepack()` operator for them, which will cause operators that use the serialized tensor to see a `Tensor` object instead of a `TensorRef` object. This will make it so that the only operators that need to prepack an argument are those that expect a weight argument, and it also avoids packing persistent tensors multiple times.

Differential Revision: [D64550560](https://our.internmc.facebook.com/intern/diff/D64550560/)

ghstack-source-id: 248784728
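For reference, here is a minimal sketch of what the consolidated API might look like, inferred purely from the call sites in the diffs below; the exact parameter names, defaults, and overloads are assumptions, not copies of the real header.

```cpp
// Hypothetical declarations inferred from call sites in this commit;
// parameter names and defaults are illustrative, not the actual header.

// Packs the serialized data behind `tensor_data` (a TensorRef) into a new
// graph tensor with the given storage type and memory layout, using the
// standard staging-to-tensor algorithm. If `passthrough` is true and
// `tensor_data` already refers to a Tensor, it is returned unchanged;
// otherwise a non-TensorRef argument is an error.
ValueRef prepack_standard(
    ComputeGraph& graph,
    const ValueRef tensor_data,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout layout,
    const bool passthrough = false);

// Convenience variant that takes the storage type (and, presumably, a
// default memory layout) from an existing tensor `like` in the graph.
ValueRef prepack_standard_like(
    ComputeGraph& graph,
    const ValueRef tensor_data,
    const ValueRef like,
    const bool passthrough = false);
```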
1 parent b1c94ab commit a7e7664

19 files changed (+198, -159 lines)

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -285,7 +285,8 @@ ValueRef ComputeGraph::add_tensor_like(
 ValueRef ComputeGraph::add_tensor_like(
     const ValueRef idx,
     const utils::GPUMemoryLayout memory_layout) {
-  return add_tensor(sizes_of(idx), dtype_of(idx), memory_layout);
+  return add_tensor(
+      sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout);
 }
 
 ValueRef ComputeGraph::add_tensor(
```

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 17 additions & 12 deletions
```diff
@@ -18,9 +18,10 @@
 
 namespace vkcompute {
 
-ValueRef prepack_arg(
+ValueRef check_and_prepack_arg(
     ComputeGraph& graph,
     ValueRef arg_ref,
+    const utils::StorageType stype,
     int64_t num_channels,
     const std::string& debug_name) {
   VK_CHECK_COND(
@@ -33,7 +34,7 @@ ValueRef prepack_arg(
   // batch_norm's param are broadcasted on the channel dimension.
   // In this implementation, we pack the weights along the x dimension, and
   // in the shader, we lookup using the along the x.
-  return prepack_if_tensor_ref(graph, arg_ref, utils::kWidthPacked);
+  return prepack_standard(graph, arg_ref, stype, utils::kWidthPacked);
 }
 
 void add_native_batch_norm_node(
@@ -51,22 +52,26 @@ void add_native_batch_norm_node(
   VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor");
   VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor");
 
+  // Only the first element of the return value is propagated. The remaining 2
+  // elements are zero-size dummy tensor.
+  ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0);
+
+  utils::StorageType stype = graph.storage_type_of(out_ref);
+
   int64_t num_channels = dim_at<kChannel4D>(in_sizes);
 
-  ValueRef arg_weight = prepack_arg(graph, weight_ref, num_channels, "weight");
-  ValueRef arg_bias = prepack_arg(graph, bias_ref, num_channels, "bias");
-  ValueRef arg_mean = prepack_arg(graph, mean_ref, num_channels, "mean");
-  ValueRef arg_var = prepack_arg(graph, var_ref, num_channels, "var");
+  ValueRef arg_weight =
+      check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight");
+  ValueRef arg_bias =
+      check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias");
+  ValueRef arg_mean =
+      check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean");
+  ValueRef arg_var =
+      check_and_prepack_arg(graph, var_ref, stype, num_channels, "var");
 
   float epsilon = graph.extract_scalar<float>(eps_ref);
 
   vTensorPtr t_in = graph.get_tensor(in_ref);
 
-  // Only the first element of the return value is propagated. The remaining 2
-  // elements are zero-size dummy tensor.
-  const auto out_tuple_val = graph.get_value_list(out_tuple_ref);
-
-  ValueRef out_ref = out_tuple_val->at(0);
-
   VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref");
   vTensorPtr t_out = graph.get_tensor(out_ref);
```

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 2 additions & 3 deletions
```diff
@@ -51,9 +51,8 @@ void add_binary_op_node(
     const ValueRef alpha,
     const ValueRef out,
     const std::string& op_name) {
-  ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
-  ValueRef arg2 =
-      prepack_if_tensor_ref(graph, in2, graph.estimate_memory_layout_of(arg1));
+  ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
+  ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
 
   vTensorPtr t_in1 = graph.get_tensor(arg1);
   vTensorPtr t_in2 = graph.get_tensor(arg2);
```
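The `BinaryOp` change above is the clearest illustration of the new `passthrough` flag: either operand of a binary op may be a regular `Tensor` or a serialized `TensorRef`, so the call opts in to passthrough explicitly instead of relying on the old implicit behavior. A hedged reading of the call, with semantics inferred from the commit message rather than the header:

```cpp
// If in1 already refers to a Tensor, arg1 is simply in1; if it is a
// TensorRef, its data is prepacked into a new tensor whose storage type
// (and layout) match `out`. Inferred semantics; not verbatim from source.
ValueRef arg1 = prepack_standard_like(graph, in1, out, /*passthrough=*/true);
```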

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 19 additions & 16 deletions
```diff
@@ -304,7 +304,7 @@ utils::uvec3 create_conv2d_global_wg_size(
 void add_conv2d_node(
     ComputeGraph& graph,
     const ValueRef in,
-    const ValueRef weight,
+    const ValueRef weight_data,
     const ValueRef bias,
     const ValueRef stride,
     const ValueRef padding,
@@ -330,19 +330,18 @@ void add_conv2d_node(
   const int64_t groups_val = graph.get_int(groups);
 
   const Conv2dMethod method =
-      get_conv2d_method(graph, weight, groups_val, transposed_val);
+      get_conv2d_method(graph, weight_data, groups_val, transposed_val);
 
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight = prepack_weights(graph, weight, method);
+  ValueRef arg_weight = prepack_weights(graph, weight_data, method);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
-      weight,
+      weight_data,
       transposed_val,
       /* storage_type = */ utils::kTexture2D,
       /* memory_layout = */ utils::kWidthPacked);
 
-  vTensorPtr t_in = graph.get_tensor(arg_in);
+  vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
   if (t_in->sizes().at(0) > 1) {
     VK_THROW("conv2d: input batch size > 1 is not supported yet!");
@@ -351,20 +350,25 @@
 
   Kernel2dParams kernel_params = create_kernel2d_params(
       graph,
-      weight,
+      weight_data,
       /*kernel_size_only = */ false,
       stride,
       padding,
       dilation);
   Conv2dParams extra_params =
-      create_conv2d_params(graph, weight, kernel_params, transposed_val);
+      create_conv2d_params(graph, weight_data, kernel_params, transposed_val);
 
   OutputParams out_params = {out_min_val, out_max_val};
 
   check_conv2d_params(kernel_params, transposed_val);
 
   vkapi::ShaderInfo shader = get_conv2d_shader(
-      graph, *t_out, /*prepack_weights = */ false, method, weight, clamp_out);
+      graph,
+      *t_out,
+      /*prepack_weights = */ false,
+      method,
+      weight_data,
+      clamp_out);
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
@@ -373,7 +377,7 @@
       graph.create_local_wg_size(out),
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {
           t_out->logical_limits_ubo(),
@@ -386,7 +390,7 @@
       {},
       // Resizing Logic
       resize_conv2d_node,
-      {weight, stride, padding, dilation, transposed, output_padding}));
+      {weight_data, stride, padding, dilation, transposed, output_padding}));
 }
 
 void add_conv1d_node(
@@ -402,9 +406,8 @@
     const ValueRef out_max,
     const ValueRef out,
     const bool clamp_out) {
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight =
-      prepack_if_tensor_ref(graph, weight, utils::kWidthPacked);
+  ValueRef arg_weight = prepack_standard(
+      graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
@@ -422,7 +425,7 @@
     out_max_val = graph.extract_scalar<float>(out_max);
   }
 
-  vTensorPtr t_in = graph.get_tensor(arg_in);
+  vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_weight = graph.get_tensor(arg_weight);
   vTensorPtr t_bias = graph.get_tensor(arg_bias);
   vTensorPtr t_out = graph.get_tensor(out);
@@ -471,7 +474,7 @@
       local_size,
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {
           t_out->logical_limits_ubo(),
```

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -57,9 +57,9 @@ void add_embedding_node(
 }
 
 void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  ValueRef weight = prepack_if_tensor_ref(graph, args[0]);
-  ValueRef in = prepack_if_tensor_ref(graph, args[1]);
+  ValueRef in = args[1];
   ValueRef out = args[5];
+  ValueRef weight = prepack_standard_like(graph, args[0], out);
 
   add_embedding_node(graph, weight, in, out);
 }
```

backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -108,9 +108,9 @@ int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) {
 }
 
 void index_select(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  ValueRef in = prepack_if_tensor_ref(graph, args[0]);
+  ValueRef in = args[0];
   ValueRef dim_ref = args[1];
-  ValueRef idx = prepack_if_tensor_ref(graph, args[2]);
+  ValueRef idx = args[2];
   ValueRef out = args[3];
 
   const int64_t dim_idx = get_dim_idx(graph, in, dim_ref);
```

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 11 additions & 6 deletions
```diff
@@ -94,8 +94,11 @@ void add_addmm_naive_node(
     const ValueRef out,
     const Params& params,
     const ValueRef mat2_is_transposed) {
-  ValueRef self = prepack_if_tensor_ref(graph, self_data, utils::kWidthPacked);
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+  utils::StorageType stype = graph.storage_type_of(out);
+  ValueRef self =
+      prepack_standard(graph, self_data, stype, utils::kWidthPacked, true);
+  ValueRef mat2 =
+      prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
 
   std::string kernel_name =
       graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
@@ -145,9 +148,11 @@ void add_addmm_optimized_node(
     const ValueRef out,
     const Params& params,
     const ValueRef mat2_is_transposed) {
+  utils::StorageType stype = graph.storage_type_of(out);
   ValueRef self =
-      prepack_if_tensor_ref(graph, self_data, utils::kChannelsPacked);
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+      prepack_standard(graph, self_data, stype, utils::kChannelsPacked, true);
+  ValueRef mat2 =
+      prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
 
   // Ensure mat1 is width packed
   ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
@@ -276,8 +281,8 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   ValueRef weight_data = args.at(1);
   ValueRef bias = args.at(2);
   ValueRef out = args.at(3);
-  ValueRef weight =
-      prepack_if_tensor_ref(graph, weight_data, utils::kWidthPacked);
+  ValueRef weight = prepack_standard(
+      graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
   ValueRef mat2_is_transposed = graph.add_scalar(true);
   if (graph.val_is_none(bias)) {
     return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
```
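Note the recurring pattern in the `Linear` changes above (and the `MatMul` changes below): the storage type for the prepacked matrix is derived from the output tensor, presumably so the packed data lands in the same kind of resource (buffer vs. texture) as the tensors it will be computed against. A condensed sketch of the pattern:

```cpp
// Illustrative condensation of the pattern, not verbatim from the source:
// prepack the serialized matrix so its storage type matches the output
// tensor it feeds, opting in to passthrough for args that may already be
// Tensors.
utils::StorageType stype = graph.storage_type_of(out);
ValueRef mat2 =
    prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
```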

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 6 additions & 3 deletions
```diff
@@ -62,7 +62,8 @@ void add_matmul_naive_buffer_node(
     const ValueRef mat2_data,
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
 
   std::string kernel_name = "matmul_naive_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -103,7 +104,8 @@ void add_matmul_naive_texture3d_node(
     const ValueRef mat2_data,
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
 
   std::string kernel_name = graph.get_bool(mat2_is_transposed)
       ? "matmul_transposed_naive"
@@ -146,7 +148,8 @@ void add_matmul_optimized_node(
     const ValueRef mat2_data,
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
-  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
 
   // Ensure mat1 is width packed
   ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
```

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 7 additions & 10 deletions
```diff
@@ -57,8 +57,8 @@ void add_native_layer_norm_node(
     ComputeGraph& graph,
     const ValueRef in,
     const ValueRef normalized_shape,
-    const ValueRef weight,
-    const ValueRef bias,
+    const ValueRef weight_data,
+    const ValueRef bias_data,
     const ValueRef eps,
     const ValueRef out) {
   const auto normalized_shape_dim =
@@ -67,19 +67,16 @@
     VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
   }
 
-  if (graph.val_is_none(weight)) {
+  if (graph.val_is_none(weight_data)) {
     VK_THROW("native_layer_norm requires weight to be non-None");
   }
 
-  if (graph.val_is_none(bias)) {
+  if (graph.val_is_none(bias_data)) {
     VK_THROW("native_layer_norm requires bias to be non-None");
   }
 
-  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight = prepack_if_tensor_ref(
-      graph, weight, graph.estimate_memory_layout_of(arg_in));
-  ValueRef arg_bias = prepack_if_tensor_ref(
-      graph, bias, graph.estimate_memory_layout_of(arg_in));
+  ValueRef arg_weight = prepack_standard_like(graph, weight_data, in);
+  ValueRef arg_bias = prepack_standard_like(graph, bias_data, in);
 
   const auto out_val = graph.get_value_list(out);
   vTensorPtr t_out = graph.get_tensor(out_val->at(0));
@@ -107,7 +104,7 @@
       // Inputs and Outputs
       {{{out_val->at(0), out_val->at(1), out_val->at(2)},
         vkapi::MemoryAccessType::WRITE},
-       {{arg_in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
+       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {t_out->logical_limits_ubo(),
        t_out->sizes_ubo(),
```

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 4 additions & 6 deletions
```diff
@@ -71,8 +71,7 @@ void add_max_pool2d_node(
     const ValueRef dilation,
     const ValueRef ceil_mode,
     const ValueRef out) {
-  ValueRef arg = prepack_if_tensor_ref(graph, in);
-  vTensorPtr t_in = graph.get_tensor(arg);
+  vTensorPtr t_in = graph.get_tensor(in);
 
   const auto out_val = graph.get_value_list(out);
   vTensorPtr t_out = graph.get_tensor(out_val->at(0));
@@ -100,7 +99,7 @@
       local_size,
       // Inputs and Outputs
       {{{out_val->at(0), out_val->at(1)}, vkapi::MemoryAccessType::WRITE},
-       {arg, vkapi::MemoryAccessType::READ}},
+       {in, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
       {
           t_out->logical_limits_ubo(),
@@ -149,8 +148,7 @@ void add_avg_pool2d_node(
     const ValueRef count_include_pad,
     const ValueRef divisor_override,
     const ValueRef out) {
-  ValueRef arg = prepack_if_tensor_ref(graph, in);
-  vTensorPtr t_in = graph.get_tensor(arg);
+  vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
 
   check_pool2d_args(*t_in, *t_out);
@@ -174,7 +172,7 @@
       local_size,
       // Inputs and Outputs
       {{out, vkapi::MemoryAccessType::WRITE},
-       {arg, vkapi::MemoryAccessType::READ}},
+       {in, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
       {t_out->logical_limits_ubo(),
        t_in->sizes_ubo(),
```
