Skip to content

Commit 931d8da

Browse files
committed
[ET-VK] Implement prepack nodes
Pull Request resolved: #6352 ## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) ghstack-source-id: 248799488
1 parent 06f44b7 commit 931d8da

File tree

9 files changed

+168
-43
lines changed

9 files changed

+168
-43
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ python_unittest(
2828
],
2929
)
3030

31+
runtime.python_library(
32+
name = "insert_prepack_nodes",
33+
srcs = ["insert_prepack_nodes.py"],
34+
visibility = [
35+
"//executorch/backends/...",
36+
],
37+
deps = [
38+
"//caffe2:torch",
39+
"//executorch/exir:pass_base",
40+
],
41+
)
42+
3143
runtime.python_library(
3244
name = "remove_local_scalar_dense",
3345
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -65,6 +77,7 @@ runtime.python_library(
6577
"//executorch/examples/...",
6678
],
6779
deps = [
80+
":insert_prepack_nodes",
6881
":int4_weight_only_quantizer",
6982
":remove_local_scalar_dense",
7083
]

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
12
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
23
VkInt4WeightOnlyQuantizer,
34
)
@@ -6,6 +7,7 @@
67
)
78

89
__all__ = [
10+
"insert_prepack_nodes",
911
"VkInt4WeightOnlyQuantizer",
1012
"RemoveLocalScalarDenseOpsTransform",
1113
]

backends/vulkan/_passes/custom_ops_defs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@
99
namespace = "et_vk"
1010
lib = torch.library.Library(namespace, "DEF")
1111

12+
#############
13+
## prepack ##
14+
#############
15+
16+
17+
def prepack_impl(x: torch.Tensor):
18+
return x
19+
20+
21+
name = "prepack"
22+
lib.define(f"{name}(Tensor x) -> Tensor")
23+
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
24+
prepack_op = getattr(getattr(torch.ops, namespace), name)
25+
1226
#####################
1327
## conv_with_clamp ##
1428
#####################
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import List
10+
11+
import executorch.backends.vulkan._passes.custom_ops_defs # noqa
12+
13+
import torch
14+
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
from torch._export.utils import is_buffer, is_param
18+
from torch.export import ExportedProgram
19+
20+
USES_WEIGHTS: List[torch._ops.OpOverload] = [
21+
exir_ops.edge.aten.embedding.default,
22+
exir_ops.edge.aten.convolution.default,
23+
exir_ops.edge.et_vk.conv_with_clamp.default,
24+
exir_ops.edge.aten.linear.default,
25+
exir_ops.edge.aten._weight_int8pack_mm.default,
26+
exir_ops.edge.et_vk.linear_weight_int4.default,
27+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
28+
exir_ops.edge.aten.native_layer_norm.default,
29+
"llama::sdpa_with_kv_cache",
30+
]
31+
32+
33+
def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
34+
"""
35+
Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
36+
is responsible for transferring the tensor data, which is serialized with the model,
37+
to a GPU tensor object during the prepacking stage of model execution.
38+
39+
Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
40+
prefer to handle prepacking within the operator. For these ops, the constant tensor
41+
data will be passed directly as an argument into the operator implementation.
42+
"""
43+
44+
def is_get_attr_node(node: torch.fx.Node) -> bool:
45+
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
46+
47+
def is_constant(node: torch.fx.Node) -> bool:
48+
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
49+
50+
def is_param_node(node: torch.fx.Node) -> bool:
51+
"""
52+
Check if the given node is a parameter within the exported program
53+
"""
54+
return (
55+
is_get_attr_node(node)
56+
or is_param(program, node)
57+
or is_buffer(program, node)
58+
or is_constant(node)
59+
)
60+
61+
def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
62+
if not is_param_node(node):
63+
return False
64+
65+
for user in node.users:
66+
if user.op == "call_function" and (
67+
# pyre-ignore [16]
68+
user.target in USES_WEIGHTS
69+
or user.target.name() in USES_WEIGHTS
70+
):
71+
return False
72+
73+
return True
74+
75+
for node in program.graph_module.graph.nodes:
76+
if not is_non_weight_param_tensor(node):
77+
continue
78+
79+
with program.graph_module.graph.inserting_after(node):
80+
prepack_node = program.graph_module.graph.create_node(
81+
"call_function",
82+
exir_ops.edge.et_vk.prepack.default,
83+
(node,),
84+
)
85+
prepack_node.meta["spec"] = node.meta["spec"]
86+
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
87+
# memory object. This pass must be executed AFTER the memory planning pass.
88+
prepack_node.meta["spec"].mem_obj_id = -1
89+
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
90+
91+
program.graph.eliminate_dead_code()
92+
return program

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ void add_binary_op_node(
5151
const ValueRef alpha,
5252
const ValueRef out,
5353
const std::string& op_name) {
54-
ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
55-
ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
56-
57-
vTensorPtr t_in1 = graph.get_tensor(arg1);
58-
vTensorPtr t_in2 = graph.get_tensor(arg2);
54+
vTensorPtr t_in1 = graph.get_tensor(in1);
55+
vTensorPtr t_in2 = graph.get_tensor(in2);
5956
vTensorPtr t_out = graph.get_tensor(out);
6057

6158
check_binary_op_args(*t_in1, *t_in2, *t_out);
@@ -81,7 +78,7 @@ void add_binary_op_node(
8178
graph.create_local_wg_size(out),
8279
// Inputs and Outputs
8380
{{out, vkapi::MemoryAccessType::WRITE},
84-
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
81+
{{in1, in2}, vkapi::MemoryAccessType::READ}},
8582
// Shader params buffers
8683
{t_out->sizes_ubo(),
8784
t_out->axis_map_ubo(),

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
911
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1012

1113
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
@@ -205,4 +207,12 @@ ValueRef prepack_direct_copy_buffer(
205207
return tensor;
206208
}
207209

210+
void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
211+
return add_prepack_standard_node(graph, args[0], args[1]);
212+
}
213+
214+
REGISTER_OPERATORS {
215+
VK_REGISTER_OP(et_vk.prepack.default, prepack_op);
216+
}
217+
208218
} // namespace vkcompute

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,11 @@ def __init__(self):
251251
self.weight = torch.rand(size=(2, 3), dtype=torch.float32)
252252

253253
def forward(self, x, y):
254-
z = torch.add(x, y, alpha=2)
255-
z = torch.add(x, y, alpha=3.14)
256-
z = z + x
257-
z = z + self.weight
258-
return z
254+
inter1 = torch.add(x, y, alpha=2)
255+
inter2 = torch.add(x, y, alpha=3.14)
256+
inter3 = inter1 * self.weight
257+
inter4 = inter2 * self.weight
258+
return inter4 - inter3
259259

260260
internal_data_module = InternalDataModule()
261261
sample_inputs = (

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
15201520
ValueRef c = graph.add_tensor(size_big, vkapi::kFloat);
15211521
ValueRef e = graph.add_tensor(size_big, vkapi::kFloat);
15221522

1523+
ValueRef w1_packed = graph.add_tensor(size_small, vkapi::kFloat);
1524+
ValueRef w2_packed = graph.add_tensor(size_small, vkapi::kFloat);
1525+
1526+
auto prepackFn = VK_GET_OP_FN("et_vk.prepack.default");
1527+
prepackFn(graph, {w1, w1_packed});
1528+
prepackFn(graph, {w2, w2_packed});
1529+
15231530
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
1524-
addFn(graph, {a.value, w1, kDummyValueRef, c});
1531+
addFn(graph, {a.value, w1_packed, kDummyValueRef, c});
15251532

15261533
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
1527-
mulFn(graph, {c, w2, e});
1534+
mulFn(graph, {c, w2_packed, e});
15281535

15291536
IOValueRef out = {};
15301537
out.value = e;
@@ -2597,8 +2604,7 @@ void test_binary_op(
25972604
std::vector<int64_t> sizes_big,
25982605
std::vector<int64_t> sizes_small,
25992606
vkapi::ScalarType dtype,
2600-
utils::GPUMemoryLayout memory_layout,
2601-
bool prepack = true) {
2607+
utils::GPUMemoryLayout memory_layout) {
26022608
GraphConfig config;
26032609
ComputeGraph graph(config);
26042610

@@ -2609,12 +2615,7 @@ void test_binary_op(
26092615
// Build graph
26102616

26112617
IOValueRef arg1 = graph.add_input_tensor(sizes_big, dtype, memory_layout);
2612-
2613-
if (prepack) {
2614-
arg2.value = arg2_w;
2615-
} else {
2616-
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
2617-
}
2618+
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
26182619

26192620
IOValueRef out;
26202621
out.value = graph.add_tensor(sizes_big, dtype, memory_layout);
@@ -2635,7 +2636,7 @@ void test_binary_op(
26352636

26362637
for (int i = 1; i < 4; i++) {
26372638
float val_arg1 = i + 1.5;
2638-
float val_arg2 = prepack ? 2.5f : i - 3.5;
2639+
float val_arg2 = i - 3.5;
26392640

26402641
float val_out = val_arg1 + val_arg2;
26412642
if (op_name == "sub") {
@@ -2648,21 +2649,14 @@ void test_binary_op(
26482649
val_out = val_arg1 / val_arg2;
26492650
}
26502651

2651-
if (prepack) {
2652-
execute_graph_and_check_output(graph, {val_arg1}, {val_out});
2653-
} else {
2654-
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
2655-
}
2652+
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
26562653
}
26572654
}
26582655

2659-
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2660-
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
2661-
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false) \
2662-
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
2663-
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
2664-
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true) \
2665-
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
2656+
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2657+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked) \
2658+
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked) \
2659+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked)
26662660

26672661
#define CALL_TEST_FN_FOR_W_PACKED(_) \
26682662
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
@@ -2677,15 +2671,15 @@ void test_binary_op(
26772671
_(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)
26782672

26792673
TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
2680-
#define RUN_TESTS(dtype, storage, layout, prepack) \
2681-
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack); \
2682-
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack); \
2683-
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack); \
2684-
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout, prepack); \
2685-
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout, prepack); \
2686-
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout, prepack); \
2687-
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout, prepack); \
2688-
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout, prepack);
2674+
#define RUN_TESTS(dtype, storage, layout) \
2675+
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout); \
2676+
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout); \
2677+
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout); \
2678+
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout); \
2679+
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout); \
2680+
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout); \
2681+
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout); \
2682+
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout);
26892683

26902684
CALL_TEST_FN_FORALL_CONDITIONS(RUN_TESTS);
26912685

backends/vulkan/vulkan_preprocess.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2020

2121
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
22+
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
2223

2324
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2425
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -86,6 +87,8 @@ def preprocess( # noqa: C901
8687

8788
_copy_module(program.graph_module, new_gm)
8889

90+
program = insert_prepack_nodes(program)
91+
8992
graph_builder = VkGraphBuilder(
9093
program, DelegateMappingBuilder(generated_identifiers=True)
9194
)

0 commit comments

Comments
 (0)