[ET-VK][AOT][ez] Introduce vulkan export utils lib #6605

Merged 2 commits on Oct 31, 2024

1 change: 1 addition & 0 deletions backends/vulkan/_passes/TARGETS
@@ -12,6 +12,7 @@ runtime.python_library(
    deps = [
        "//caffe2:torch",
        "//executorch/exir:pass_base",
+        "//executorch/backends/vulkan:utils_lib",
    ],
)

33 changes: 12 additions & 21 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -6,15 +6,17 @@

# pyre-strict

+from copy import deepcopy
+
import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.backends.vulkan.op_registry import handles_own_prepacking
+from executorch.backends.vulkan.utils import is_param_node

from executorch.exir.dialects._ops import ops as exir_ops

-from torch._export.utils import is_buffer, is_param
from torch.export import ExportedProgram

@@ -29,25 +31,8 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
    argument into the operator implementation.
    """

-    def is_get_attr_node(node: torch.fx.Node) -> bool:
-        return isinstance(node, torch.fx.Node) and node.op == "get_attr"
-
-    def is_constant(node: torch.fx.Node) -> bool:
-        return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
-
-    def is_param_node(node: torch.fx.Node) -> bool:
-        """
-        Check if the given node is a parameter within the exported program
-        """
-        return (
-            is_get_attr_node(node)
-            or is_param(program, node)
-            or is_buffer(program, node)
-            or is_constant(node)
-        )
-
    def prepack_not_required(node: torch.fx.Node) -> bool:
-        if not is_param_node(node):
+        if not is_param_node(program, node):
            return True

        for user in node.users:
@@ -69,9 +54,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
                exir_ops.edge.et_vk.prepack.default,
                (node,),
            )
-            prepack_node.meta["spec"] = node.meta["spec"]
+            # This pass assumes that SpecPropPass() has already been applied.
+            assert "spec" in node.meta
+            # Validate that the original node is marked as a constant. Constant tensors
+            # do not participate in memory planning.
+            assert node.meta["spec"].const
+            prepack_node.meta["val"] = node.meta["val"]
+            prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
            # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
-            # memory object. This pass must be executed AFTER the memory planning pass.
+            # memory object.
            prepack_node.meta["spec"].mem_obj_id = -1
            node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)

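For context, here is a minimal, self-contained sketch (not part of this PR) of the insertion pattern the pass uses: create a node that consumes a constant placeholder, then reroute every other user to it. It uses plain torch.fx, with a prepack_stub function standing in for exir_ops.edge.et_vk.prepack.default; the stub, the toy module, and the hard-coded placeholder name are assumptions for illustration only.

import torch
import torch.fx


def prepack_stub(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for exir_ops.edge.et_vk.prepack.default.
    return t


class M(torch.nn.Module):
    def forward(self, x, w):
        return x + w


gm = torch.fx.symbolic_trace(M())

for node in list(gm.graph.nodes):
    # In the real pass, is_param_node(program, node) and prepack_not_required()
    # decide whether a placeholder gets a prepack node.
    if node.op == "placeholder" and node.name == "w":
        with gm.graph.inserting_after(node):
            prepack_node = gm.graph.create_node(
                "call_function", prepack_stub, (node,)
            )
        # Reroute all users except the prepack node itself, mirroring
        # node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y).
        node.replace_all_uses_with(
            prepack_node, delete_user_cb=lambda n: n is not prepack_node
        )

gm.recompile()
print(gm.code)  # the add now consumes the prepack stub instead of w directly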
37 changes: 10 additions & 27 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -12,6 +12,11 @@
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch
+from executorch.backends.vulkan.utils import (
+    is_constant,
+    is_get_attr_node,
+    is_param_node,
+)
from executorch.exir.backend.utils import DelegateMappingBuilder

from executorch.exir.tensor import TensorSpec
@@ -68,34 +73,12 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        else:
            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

-    def is_constant(self, node: Node):
-        return (
-            node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants
-        )
-
-    def is_get_attr_node(self, node: Node) -> bool:
-        """
-        Returns true if the given node is a get attr node for a tensor of the model
-        """
-        return isinstance(node, Node) and node.op == "get_attr"
-
-    def is_param_node(self, node: Node) -> bool:
-        """
-        Check if the given node is a parameter within the exported program
-        """
-        return (
-            self.is_get_attr_node(node)
-            or is_param(self.program, node)
-            or is_buffer(self.program, node)
-            or self.is_constant(node)
-        )
-
    def get_constant(self, node: Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported program.
        Returns None if the node is not a constant within the exported program
        """
-        if self.is_constant(node):
+        if is_constant(self.program, node):
            constant_name = (
                self.program.graph_signature.inputs_to_lifted_tensor_constants[
                    node.name
@@ -116,9 +99,9 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
-        elif self.is_constant(node):
+        elif is_constant(self.program, node):
            tensor = self.get_constant(node)
-        elif self.is_get_attr_node(node):
+        elif is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graph
            try:
                tensor = getattr(node.graph.owning_module, node.target)
@@ -132,7 +115,7 @@ def get_param_tensor(self, node: Node) -> torch.Tensor:

    def maybe_add_constant_tensor(self, node: Node) -> int:
        constant_id = -1
-        if self.is_param_node(node):
+        if is_param_node(self.program, node):
            constant_id = len(self.const_tensors)
            self.const_tensors.append(self.get_param_tensor(node))

@@ -280,7 +263,7 @@ def process_placeholder_node(self, node: Node) -> None:
        if len(node.users) == 0:
            return None
        ids = self.create_node_value(node)
-        if not self.is_param_node(node):
+        if not is_param_node(self.program, node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
13 changes: 13 additions & 0 deletions backends/vulkan/targets.bzl
@@ -213,6 +213,19 @@ def define_common_targets(is_fbcode = False):
    ## AOT targets
    ##
    if is_fbcode:
+        runtime.python_library(
+            name = "utils_lib",
+            srcs = [
+                "utils.py",
+            ],
+            visibility = [
+                "//executorch/backends/vulkan/...",
+            ],
+            deps = [
+                "//caffe2:torch",
+            ]
+        )
+
        runtime.python_library(
            name = "custom_ops_lib",
            srcs = [
30 changes: 30 additions & 0 deletions backends/vulkan/utils.py
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch._export.utils import is_buffer, is_param

from torch.export import ExportedProgram


def is_get_attr_node(node: torch.fx.Node) -> bool:
return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
"""
Check if the given node is a parameter within the exported program
"""
return (
is_get_attr_node(node)
or is_param(program, node)
or is_buffer(program, node)
or is_constant(program, node)
)
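As a usage illustration (not part of this PR), here is a sketch of how these helpers classify the placeholders of an exported program. It assumes torch and executorch are installed; the toy model is an assumption for illustration.

import torch
from torch.export import export

from executorch.backends.vulkan.utils import is_constant, is_param_node


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


program = export(ToyModel(), (torch.randn(2, 4),))

for node in program.graph.nodes:
    if node.op == "placeholder":
        # Weight and bias placeholders are param nodes; the user input x is not.
        print(node.name, is_param_node(program, node), is_constant(program, node))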
99 changes: 70 additions & 29 deletions backends/vulkan/vulkan_preprocess.py
@@ -17,8 +17,10 @@
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

-from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
-from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
+from executorch.backends.vulkan._passes import (
+    insert_prepack_nodes,
+    RemoveLocalScalarDenseOpsTransform,
+)

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -32,6 +34,7 @@
    PreprocessResult,
)
from executorch.exir.backend.utils import DelegateMappingBuilder
+from executorch.exir.pass_base import ExportPass, PassBase

from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -46,6 +49,35 @@
DEFAULT_DEBUG_HANDLE = 65535


+# pyre-ignore
+def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
+    for p in passes:
+        if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
+            new_gm = program.graph_module
+            # This is a workaround to allow the memory planning pass to work without
+            # having to first apply ToOutVarPass(). See the `greedy()` function in
+            # `exir.memory_planning`; if this attribute isn't set, assertions in
+            # `collect_spec_from_nodes()` will fail.
+            if isinstance(p, MemoryPlanningPass):
+                new_gm.encounter_to_out_var_failure = True
+
+            new_gm_res = p(new_gm)
+            assert new_gm_res is not None
+            new_gm = new_gm_res.graph_module
+
+            # See the application of this function in exir/program/_program.py for more
+            # details on why this step is necessary.
+            if isinstance(p, SpecPropPass):
+                p.update_placeholder_tensor_specs(program, new_gm)
+
+            _copy_module(program.graph_module, new_gm)
+        else:
+            program = p(program)
+
+    return program


@final
class VulkanBackend(BackendDetails):
@classmethod
@@ -57,35 +89,44 @@ def preprocess(  # noqa: C901
    ) -> PreprocessResult:
        program = unsafe_remove_auto_functionalized_pass(program)

-        passes = [
-            RemoveCloneOpsTransform(),
-            AddmmToLinearTransform(),
-            FuseDequantLinearPass(),
-            FuseViewCopyTransform(),
-            FuseBatchNormWithConvPass(program),
-            FuseClampPass(),
-            SpecPropPass(),
-            ConstraintBasedSymShapeEvalPass(),
-            RemoveLocalScalarDenseOpsTransform(),
-            MemoryPlanningPass(),
-        ]
-
-        new_gm = program.graph_module
-
-        for p in passes:
-            # This is a workaround to allow the memory planning pass to work without
-            # having to first apply ToOutVarPass(). See the `greedy()` function in
-            # `exir.memory_planning`; if this attribute isn't set, assertions in
-            # `collect_spec_from_nodes()` will fail.
-            if isinstance(p, MemoryPlanningPass):
-                new_gm.encounter_to_out_var_failure = True
-            new_gm_res = p(new_gm)
-            assert new_gm_res is not None
-            new_gm = new_gm_res.graph_module
+        # First, apply passes that fuse/remove operators to consolidate the graph
+        # structure while still preserving an "ATen-compliant" graph (i.e. all
+        # arguments to ATen operators must match the ATen function schema).
+        program = apply_passes(
+            program,
+            [
+                RemoveCloneOpsTransform(),
+                AddmmToLinearTransform(),
+                FuseDequantLinearPass(),
+                FuseViewCopyTransform(),
+                FuseBatchNormWithConvPass(program),
+                FuseClampPass(),
+            ],
+        )

-        _copy_module(program.graph_module, new_gm)
+        # Next, annotate tensor nodes with the TensorSpec structs needed for dynamic
+        # shapes and memory planning. Until this point, the graph must be
+        # ATen-compliant because SpecPropPass calls the underlying ATen operators
+        # during its execution.
+        program = apply_passes(program, [SpecPropPass()])

+        # Apply graph transforms which either require `TensorSpec`s to have been
+        # created or would create a non-ATen-compliant graph structure.
+        program = apply_passes(
+            program,
+            [
+                # Since this pass may replace a scalar argument with a tensor
+                # argument, it may result in a non-ATen-compliant graph structure.
+                RemoveLocalScalarDenseOpsTransform(),
+                insert_prepack_nodes,
+            ],
+        )

-        program = insert_prepack_nodes(program)
+        # Finally, apply the dynamic shape passes and the memory planning pass.
+        # These passes must be applied only once the graph structure is finalized.
+        program = apply_passes(
+            program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()]
+        )

graph_builder = VkGraphBuilder(
program, DelegateMappingBuilder(generated_identifiers=True)
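To make the dispatch in apply_passes concrete: ExportPass/PassBase subclasses are invoked on program.graph_module and return a PassResult, while plain callables such as insert_prepack_nodes receive and return the whole ExportedProgram. A hedged sketch of both branches follows, not part of this PR; the no-op pass, toy model, and direct calls are assumptions for illustration.

import torch
from torch.export import export

from executorch.exir import to_edge
from executorch.exir.pass_base import ExportPass


class NoOpPass(ExportPass):
    # Graph-module-level pass: apply_passes would call it on program.graph_module
    # and unwrap the returned PassResult's graph_module.
    pass


def program_level_pass(program):
    # Program-level callable: apply_passes hands it the ExportedProgram directly.
    return program


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


program = to_edge(export(ToyModel(), (torch.randn(2),))).exported_program()

# Mirrors the two branches in apply_passes:
result = NoOpPass()(program.graph_module)  # ExportPass branch -> PassResult
assert result is not None
program = program_level_pass(program)  # plain-callable branch -> ExportedProgram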