Skip to content

Commit da51c81

Browse files
committed
[ET-VK][AOT] Define pass application order
Pull Request resolved: #6577 ## Changes The goal of this diff is to enforce a specific structure in how graph transform passes are applied during `vulkan_preprocess`. This will help make sure that certain passes are applied at the correct time, and that prerequisite conditions for passes are fulfilled before they are applied. See the comments in `vulkan_preprocess.py` for more details. ghstack-source-id: 251223076 Differential Revision: [D65234843](https://our.internmc.facebook.com/intern/diff/D65234843/)
1 parent 1972e69 commit da51c81

File tree

2 files changed

+80
-31
lines changed

2 files changed

+80
-31
lines changed

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
from copy import deepcopy
10+
911
import executorch.backends.vulkan.custom_ops_lib # noqa
1012

1113
import torch
@@ -69,9 +71,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
6971
exir_ops.edge.et_vk.prepack.default,
7072
(node,),
7173
)
72-
prepack_node.meta["spec"] = node.meta["spec"]
74+
# This pass assumes that the SpecPropPass() has already been applied
75+
assert "spec" in node.meta
76+
# Validate that the original node is marked as a constant. Constant tensors
77+
# do not participate in memory planning.
78+
assert node.meta["spec"].const
79+
prepack_node.meta["val"] = node.meta["val"]
80+
prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
7381
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
74-
# memory object. This pass must be executed AFTER the memory planning pass.
82+
# memory object.
7583
prepack_node.meta["spec"].mem_obj_id = -1
7684
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
7785

backends/vulkan/vulkan_preprocess.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
1818
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
1919

20-
from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
21-
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
20+
from executorch.backends.vulkan._passes import (
21+
insert_prepack_nodes,
22+
RemoveLocalScalarDenseOpsTransform,
23+
)
2224

2325
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
2426
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -32,6 +34,7 @@
3234
PreprocessResult,
3335
)
3436
from executorch.exir.backend.utils import DelegateMappingBuilder
37+
from executorch.exir.pass_base import ExportPass, PassBase
3538

3639
from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
3740

@@ -46,6 +49,35 @@
4649
DEFAULT_DEBUG_HANDLE = 65535
4750

4851

52+
# pyre-ignore
53+
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
54+
for p in passes:
55+
56+
if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
57+
new_gm = program.graph_module
58+
# This is a workaround to allow the memory planning pass to work without
59+
# having to first apply ToOutVarPass(). See the `greedy()` function in
60+
# `exir.memory_planning`; if this attribute isn't set, assertions in
61+
# `collect_spec_from_nodes()` will fail.
62+
if isinstance(p, MemoryPlanningPass):
63+
new_gm.encounter_to_out_var_failure = True
64+
65+
new_gm_res = p(new_gm)
66+
assert new_gm_res is not None
67+
new_gm = new_gm_res.graph_module
68+
69+
# See the application of this function in exir/program/_program.py for more
70+
# details on why this step is necessary.
71+
if isinstance(p, SpecPropPass):
72+
p.update_placeholder_tensor_specs(program, new_gm)
73+
74+
_copy_module(program.graph_module, new_gm)
75+
else:
76+
program = p(program)
77+
78+
return program
79+
80+
4981
@final
5082
class VulkanBackend(BackendDetails):
5183
@classmethod
@@ -57,35 +89,44 @@ def preprocess( # noqa: C901
5789
) -> PreprocessResult:
5890
program = unsafe_remove_auto_functionalized_pass(program)
5991

60-
passes = [
61-
RemoveCloneOpsTransform(),
62-
AddmmToLinearTransform(),
63-
FuseDequantLinearPass(),
64-
FuseViewCopyTransform(),
65-
FuseBatchNormWithConvPass(program),
66-
FuseClampPass(),
67-
SpecPropPass(),
68-
ConstraintBasedSymShapeEvalPass(),
69-
RemoveLocalScalarDenseOpsTransform(),
70-
MemoryPlanningPass(),
71-
]
72-
73-
new_gm = program.graph_module
74-
75-
for p in passes:
76-
# This is a workaround to allow the memory planning pass to work without
77-
# having to first apply ToOutVarPass(). See the `greedy()` function in
78-
# `exir.memory_planning`; if this attribute isn't set, assertions in
79-
# `collect_spec_from_nodes()` will fail.
80-
if isinstance(p, MemoryPlanningPass):
81-
new_gm.encounter_to_out_var_failure = True
82-
new_gm_res = p(new_gm)
83-
assert new_gm_res is not None
84-
new_gm = new_gm_res.graph_module
92+
# First, apply passes that fuse/remove operators to consolidate the graph
93+
# structure but still preserve an "ATen-compliant" graph structure (i.e. all
94+
# arguments to ATen operators must match the ATen function schema).
95+
program = apply_passes(
96+
program,
97+
[
98+
RemoveCloneOpsTransform(),
99+
AddmmToLinearTransform(),
100+
FuseDequantLinearPass(),
101+
FuseViewCopyTransform(),
102+
FuseBatchNormWithConvPass(program),
103+
FuseClampPass(),
104+
],
105+
)
85106

86-
_copy_module(program.graph_module, new_gm)
107+
# Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
108+
# shapes and memory planning. Until this point, the graph must be ATen compliant
109+
# because SpecPropPass will be calling the underlying ATen operators during its
110+
# execution.
111+
program = apply_passes(program, [SpecPropPass()])
112+
113+
# Apply graph transforms which either require `TensorSpec`s to have been created
114+
# or would create a non-ATen-compliant graph structure.
115+
program = apply_passes(
116+
program,
117+
[
118+
# Since this pass may replace a scalar argument with a tensor argument,
119+
# this pass may result in a non ATen compliant graph structure.
120+
RemoveLocalScalarDenseOpsTransform(),
121+
insert_prepack_nodes,
122+
],
123+
)
87124

88-
program = insert_prepack_nodes(program)
125+
# Finally, apply dynamic shape passes and memory planning pass. These passes
126+
# must be applied only when the graph structure is finalized.
127+
program = apply_passes(
128+
program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()]
129+
)
89130

90131
graph_builder = VkGraphBuilder(
91132
program, DelegateMappingBuilder(generated_identifiers=True)

0 commit comments

Comments
 (0)