Skip to content

Commit a130902

Browse files
committed
[ET-VK] Implement prepack nodes
## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) ghstack-source-id: 248785585 Pull Request resolved: #6352
1 parent a7e7664 commit a130902

File tree

8 files changed

+144
-10
lines changed

8 files changed

+144
-10
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,18 @@ python_unittest(
2828
],
2929
)
3030

31+
runtime.python_library(
32+
name = "insert_prepack_nodes",
33+
srcs = ["insert_prepack_nodes.py"],
34+
visibility = [
35+
"//executorch/backends/...",
36+
],
37+
deps = [
38+
"//caffe2:torch",
39+
"//executorch/exir:pass_base",
40+
],
41+
)
42+
3143
runtime.python_library(
3244
name = "remove_local_scalar_dense",
3345
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -65,6 +77,7 @@ runtime.python_library(
6577
"//executorch/examples/...",
6678
],
6779
deps = [
80+
":insert_prepack_nodes",
6881
":int4_weight_only_quantizer",
6982
":remove_local_scalar_dense",
7083
]

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
12
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
23
VkInt4WeightOnlyQuantizer,
34
)
@@ -6,6 +7,7 @@
67
)
78

89
__all__ = [
10+
"insert_prepack_nodes",
911
"VkInt4WeightOnlyQuantizer",
1012
"RemoveLocalScalarDenseOpsTransform",
1113
]

backends/vulkan/_passes/custom_ops_defs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@
99
namespace = "et_vk"
1010
lib = torch.library.Library(namespace, "DEF")
1111

12+
#############
13+
## prepack ##
14+
#############
15+
16+
17+
def prepack_impl(x: torch.Tensor):
18+
return x
19+
20+
21+
name = "prepack"
22+
lib.define(f"{name}(Tensor x) -> Tensor")
23+
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
24+
prepack_op = getattr(getattr(torch.ops, namespace), name)
25+
1226
#####################
1327
## conv_with_clamp ##
1428
#####################
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import List
10+
11+
import executorch.backends.vulkan._passes.custom_ops_defs # noqa
12+
13+
import torch
14+
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
from torch._export.utils import is_buffer, is_param
18+
from torch.export import ExportedProgram
19+
20+
USES_WEIGHTS: List[torch._ops.OpOverload] = [
21+
exir_ops.edge.aten.embedding.default,
22+
exir_ops.edge.aten.convolution.default,
23+
exir_ops.edge.et_vk.conv_with_clamp.default,
24+
exir_ops.edge.aten.linear.default,
25+
exir_ops.edge.aten._weight_int8pack_mm.default,
26+
exir_ops.edge.et_vk.linear_weight_int4.default,
27+
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
28+
exir_ops.edge.aten.native_layer_norm.default,
29+
"llama::sdpa_with_kv_cache",
30+
]
31+
32+
33+
def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
34+
"""
35+
Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
36+
is responsible for transferring the tensor data, which is serialized with the model,
37+
to a GPU tensor object during the prepacking stage of model execution.
38+
39+
Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
40+
prefer to handle prepacking within the operator. For these ops, the constant tensor
41+
data will be passed directly as an argument into the operator implementation.
42+
"""
43+
44+
def is_get_attr_node(node: torch.fx.Node) -> bool:
45+
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
46+
47+
def is_constant(node: torch.fx.Node) -> bool:
48+
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
49+
50+
def is_param_node(node: torch.fx.Node) -> bool:
51+
"""
52+
Check if the given node is a parameter within the exported program
53+
"""
54+
return (
55+
is_get_attr_node(node)
56+
or is_param(program, node)
57+
or is_buffer(program, node)
58+
or is_constant(node)
59+
)
60+
61+
def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
62+
if not is_param_node(node):
63+
return False
64+
65+
for user in node.users:
66+
if user.op == "call_function" and (
67+
# pyre-ignore [16]
68+
user.target in USES_WEIGHTS
69+
or user.target.name() in USES_WEIGHTS
70+
):
71+
return False
72+
73+
return True
74+
75+
for node in program.graph_module.graph.nodes:
76+
if not is_non_weight_param_tensor(node):
77+
continue
78+
79+
with program.graph_module.graph.inserting_after(node):
80+
prepack_node = program.graph_module.graph.create_node(
81+
"call_function",
82+
exir_ops.edge.et_vk.prepack.default,
83+
(node,),
84+
)
85+
prepack_node.meta["spec"] = node.meta["spec"]
86+
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
87+
# memory object. This pass must be executed AFTER the memory planning pass.
88+
prepack_node.meta["spec"].mem_obj_id = -1
89+
node.replace_all_uses_with(prepack_node, lambda x: x != prepack_node)
90+
91+
program.graph.eliminate_dead_code()
92+
return program

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ void add_binary_op_node(
5151
const ValueRef alpha,
5252
const ValueRef out,
5353
const std::string& op_name) {
54-
ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
55-
ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
54+
VK_CHECK_COND(graph.val_is_tensor(in1));
55+
VK_CHECK_COND(graph.val_is_tensor(in2));
5656

57-
vTensorPtr t_in1 = graph.get_tensor(arg1);
58-
vTensorPtr t_in2 = graph.get_tensor(arg2);
57+
vTensorPtr t_in1 = graph.get_tensor(in1);
58+
vTensorPtr t_in2 = graph.get_tensor(in2);
5959
vTensorPtr t_out = graph.get_tensor(out);
6060

6161
check_binary_op_args(*t_in1, *t_in2, *t_out);
@@ -81,7 +81,7 @@ void add_binary_op_node(
8181
graph.create_local_wg_size(out),
8282
// Inputs and Outputs
8383
{{out, vkapi::MemoryAccessType::WRITE},
84-
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
84+
{{in1, in2}, vkapi::MemoryAccessType::READ}},
8585
// Shader params buffers
8686
{t_out->sizes_ubo(),
8787
t_out->axis_map_ubo(),

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
911
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1012

1113
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
@@ -204,4 +206,12 @@ ValueRef prepack_direct_copy_buffer(
204206
return tensor;
205207
}
206208

209+
void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
210+
return add_standard_prepack_node(graph, args[0], args[1]);
211+
}
212+
213+
REGISTER_OPERATORS {
214+
VK_REGISTER_OP(et_vk.prepack.default, prepack_op);
215+
}
216+
207217
} // namespace vkcompute

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,11 @@ def __init__(self):
251251
self.weight = torch.rand(size=(2, 3), dtype=torch.float32)
252252

253253
def forward(self, x, y):
254-
z = torch.add(x, y, alpha=2)
255-
z = torch.add(x, y, alpha=3.14)
256-
z = z + x
257-
z = z + self.weight
258-
return z
254+
inter1 = torch.add(x, y, alpha=2)
255+
inter2 = torch.add(x, y, alpha=3.14)
256+
inter3 = inter1 * self.weight
257+
inter4 = inter2 * self.weight
258+
return inter4 - inter3
259259

260260
internal_data_module = InternalDataModule()
261261
sample_inputs = (

backends/vulkan/vulkan_preprocess.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2020

2121
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
22+
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
2223

2324
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2425
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -86,6 +87,8 @@ def preprocess( # noqa: C901
8687

8788
_copy_module(program.graph_module, new_gm)
8889

90+
program = insert_prepack_nodes(program)
91+
8992
graph_builder = VkGraphBuilder(
9093
program, DelegateMappingBuilder(generated_identifiers=True)
9194
)

0 commit comments

Comments
 (0)