[ET-VK] Implement prepack nodes #6352

Closed
wants to merge 5 commits

2 changes: 1 addition & 1 deletion backends/transforms/fuse_conv_with_clamp.py
@@ -65,7 +65,7 @@ def call(self, graph_module: torch.fx.GraphModule):
with graph_module.graph.inserting_before(preceding_op):
conv_activation_node = graph_module.graph.create_node(
"call_function",
torch.ops.et_vk.conv_with_clamp.default,
exir_ops.edge.et_vk.conv_with_clamp.default,
new_args,
)

13 changes: 13 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -28,6 +28,18 @@ python_unittest(
],
)

runtime.python_library(
name = "insert_prepack_nodes",
srcs = ["insert_prepack_nodes.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
],
)

runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -65,6 +77,7 @@ runtime.python_library(
"//executorch/examples/...",
],
deps = [
":insert_prepack_nodes",
":int4_weight_only_quantizer",
":remove_local_scalar_dense",
]
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -1,3 +1,4 @@
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
VkInt4WeightOnlyQuantizer,
)
@@ -6,6 +7,7 @@
)

__all__ = [
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
"RemoveLocalScalarDenseOpsTransform",
]
14 changes: 14 additions & 0 deletions backends/vulkan/_passes/custom_ops_defs.py
@@ -9,6 +9,20 @@
namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")

#############
## prepack ##
#############


def prepack_impl(x: torch.Tensor):
return x


name = "prepack"
lib.define(f"{name}(Tensor x) -> Tensor")
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
prepack_op = getattr(getattr(torch.ops, namespace), name)

#####################
## conv_with_clamp ##
#####################
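A note on the `prepack` definition above: at the Python level the op is just the identity, so tracing works out of the box; the actual transfer to a GPU tensor object happens in the Vulkan runtime. A minimal eager-mode sketch (tensor shape is illustrative):

```python
import torch

# Importing the defs module registers the et_vk ops via torch.library.
import executorch.backends.vulkan._passes.custom_ops_defs  # noqa

x = torch.rand(2, 3)
y = torch.ops.et_vk.prepack(x)  # eager impl returns its input unchanged
assert torch.equal(x, y)
```
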
92 changes: 92 additions & 0 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List

import executorch.backends.vulkan._passes.custom_ops_defs # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import is_buffer, is_param
from torch.export import ExportedProgram

USES_WEIGHTS: List[torch._ops.OpOverload] = [
exir_ops.edge.aten.embedding.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.et_vk.conv_with_clamp.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten._weight_int8pack_mm.default,
exir_ops.edge.et_vk.linear_weight_int4.default,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
"llama::sdpa_with_kv_cache",
]


def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
"""
Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
is responsible for transferring the tensor data, which is serialized with the model,
to a GPU tensor object during the prepacking stage of model execution.

Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
prefer to handle prepacking within the operator. For these ops, the constant tensor
data will be passed directly as an argument into the operator implementation.
"""

def is_get_attr_node(node: torch.fx.Node) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "get_attr"

def is_constant(node: torch.fx.Node) -> bool:
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants

def is_param_node(node: torch.fx.Node) -> bool:
"""
Check if the given node is a parameter within the exported program
"""
return (
is_get_attr_node(node)
or is_param(program, node)
or is_buffer(program, node)
or is_constant(node)
)

def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
if not is_param_node(node):
return False

for user in node.users:
if user.op == "call_function" and (
# pyre-ignore [16]
user.target in USES_WEIGHTS
or user.target.name() in USES_WEIGHTS
):
return False

return True

for node in program.graph_module.graph.nodes:
if not is_non_weight_param_tensor(node):
continue

with program.graph_module.graph.inserting_after(node):
prepack_node = program.graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.prepack.default,
(node,),
)
prepack_node.meta["spec"] = node.meta["spec"]
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object. This pass must be executed AFTER the memory planning pass.
prepack_node.meta["spec"].mem_obj_id = -1
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)

program.graph.eliminate_dead_code()
return program
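
As a rough sketch, assuming `edge_program` is an `ExportedProgram` that has already been lowered to the edge dialect and memory-planned (the pass reads `node.meta["spec"]`, so it must run after memory planning), the pass rewrites uses of a constant tensor like this; the op names in the comments are illustrative:

```python
from executorch.backends.vulkan._passes import insert_prepack_nodes

# Before: a lifted constant feeds a non-weight op directly.
#     %weight : placeholder
#     %mul    : call_function[aten.mul.Tensor](%x, %weight)
#
# After: a prepack node is interposed and takes over all uses of the constant.
#     %weight  : placeholder
#     %prepack : call_function[et_vk.prepack.default](%weight)
#     %mul     : call_function[aten.mul.Tensor](%x, %prepack)
edge_program = insert_prepack_nodes(edge_program)
```
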
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
@@ -205,4 +207,12 @@ ValueRef prepack_direct_copy_buffer(
return tensor;
}

void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
return add_prepack_standard_node(graph, args[0], args[1]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(et_vk.prepack.default, prepack_op);
}

} // namespace vkcompute
10 changes: 5 additions & 5 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -251,11 +251,11 @@ def __init__(self):
self.weight = torch.rand(size=(2, 3), dtype=torch.float32)

def forward(self, x, y):
z = torch.add(x, y, alpha=2)
z = torch.add(x, y, alpha=3.14)
z = z + x
z = z + self.weight
return z
inter1 = torch.add(x, y, alpha=2)
inter2 = torch.add(x, y, alpha=3.14)
inter3 = inter1 * self.weight
inter4 = inter2 * self.weight
return inter4 - inter3

internal_data_module = InternalDataModule()
sample_inputs = (
60 changes: 26 additions & 34 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
ValueRef c = graph.add_tensor(size_big, vkapi::kFloat);
ValueRef e = graph.add_tensor(size_big, vkapi::kFloat);

ValueRef w1_packed = graph.add_tensor(size_small, vkapi::kFloat);
ValueRef w2_packed = graph.add_tensor(size_small, vkapi::kFloat);

auto prepackFn = VK_GET_OP_FN("et_vk.prepack.default");
prepackFn(graph, {w1, w1_packed});
prepackFn(graph, {w2, w2_packed});

auto addFn = VK_GET_OP_FN("aten.add.Tensor");
addFn(graph, {a.value, w1, kDummyValueRef, c});
addFn(graph, {a.value, w1_packed, kDummyValueRef, c});

auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
mulFn(graph, {c, w2, e});
mulFn(graph, {c, w2_packed, e});

IOValueRef out = {};
out.value = e;
@@ -2597,24 +2604,16 @@ void test_binary_op(
std::vector<int64_t> sizes_big,
std::vector<int64_t> sizes_small,
vkapi::ScalarType dtype,
utils::GPUMemoryLayout memory_layout,
bool prepack = true) {
utils::GPUMemoryLayout memory_layout) {
GraphConfig config;
ComputeGraph graph(config);

IOValueRef arg2{};

CREATE_WEIGHT_TENSOR(arg2_w, sizes_small, dtype, 2.5f);

// Build graph

IOValueRef arg1 = graph.add_input_tensor(sizes_big, dtype, memory_layout);

if (prepack) {
arg2.value = arg2_w;
} else {
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
}
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);

IOValueRef out;
out.value = graph.add_tensor(sizes_big, dtype, memory_layout);
@@ -2635,7 +2634,7 @@

for (int i = 1; i < 4; i++) {
float val_arg1 = i + 1.5;
float val_arg2 = prepack ? 2.5f : i - 3.5;
float val_arg2 = i - 3.5;

float val_out = val_arg1 + val_arg2;
if (op_name == "sub") {
@@ -2648,21 +2647,14 @@
val_out = val_arg1 / val_arg2;
}

if (prepack) {
execute_graph_and_check_output(graph, {val_arg1}, {val_out});
} else {
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
}
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
}
}

#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false) \
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true) \
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked) \
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked) \
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked)

#define CALL_TEST_FN_FOR_W_PACKED(_) \
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
Expand All @@ -2677,15 +2669,15 @@ void test_binary_op(
_(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)

TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
#define RUN_TESTS(dtype, storage, layout, prepack) \
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack); \
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack); \
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack); \
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout, prepack); \
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout, prepack); \
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout, prepack); \
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout, prepack); \
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout, prepack);
#define RUN_TESTS(dtype, storage, layout) \
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout); \
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout); \
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout); \
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout); \
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout); \
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout); \
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout); \
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout);

CALL_TEST_FN_FORALL_CONDITIONS(RUN_TESTS);

3 changes: 3 additions & 0 deletions backends/vulkan/vulkan_preprocess.py
@@ -19,6 +19,7 @@
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -86,6 +87,8 @@ def preprocess(  # noqa: C901

_copy_module(program.graph_module, new_gm)

program = insert_prepack_nodes(program)

graph_builder = VkGraphBuilder(
program, DelegateMappingBuilder(generated_identifiers=True)
)