17
17
from executorch .backends .transforms .fuse_view_copy import FuseViewCopyTransform
18
18
from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
19
19
20
- from executorch .backends .vulkan ._passes import RemoveLocalScalarDenseOpsTransform
21
- from executorch .backends .vulkan ._passes .insert_prepack_nodes import insert_prepack_nodes
20
+ from executorch .backends .vulkan ._passes import (
21
+ insert_prepack_nodes ,
22
+ RemoveLocalScalarDenseOpsTransform ,
23
+ )
22
24
23
25
from executorch .backends .vulkan .serialization .vulkan_graph_builder import VkGraphBuilder
24
26
from executorch .backends .vulkan .serialization .vulkan_graph_serialize import (
32
34
PreprocessResult ,
33
35
)
34
36
from executorch .exir .backend .utils import DelegateMappingBuilder
37
+ from executorch .exir .pass_base import ExportPass , PassBase
35
38
36
39
from executorch .exir .passes import MemoryPlanningPass , SpecPropPass
37
40
# Sentinel debug handle (0xFFFF == 65535, the max unsigned 16-bit value);
# presumably used for nodes lacking a source-level debug handle — confirm
# against the delegate mapping usage elsewhere in this file.
DEFAULT_DEBUG_HANDLE = 0xFFFF
# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
    """Run a sequence of graph transforms over ``program`` and return it.

    Each entry of ``passes`` is one of:

    * an ``ExportPass`` / ``PassBase`` instance — it is invoked on the
      program's ``graph_module``, and the resulting module is copied back
      into ``program`` in place via ``_copy_module``;
    * any other callable taking and returning an ``ExportedProgram``
      (e.g. ``insert_prepack_nodes``), which replaces ``program`` with its
      return value.

    Args:
        program: The ExportedProgram to transform.
        passes: Iterable of pass instances and/or program-level callables,
            applied in order.

    Returns:
        The transformed ExportedProgram.
    """
    for p in passes:
        # Idiomatic single check replaces the original
        # `issubclass(type(p), ExportPass) or issubclass(type(p), PassBase)`;
        # isinstance with a class tuple is equivalent and also honors
        # virtual (ABC-registered) subclasses.
        if isinstance(p, (ExportPass, PassBase)):
            new_gm = program.graph_module
            # This is a workaround to allow the memory planning pass to work
            # without having to first apply ToOutVarPass(). See the `greedy()`
            # function in `exir.memory_planning`; if this attribute isn't set,
            # assertions in `collect_spec_from_nodes()` will fail.
            if isinstance(p, MemoryPlanningPass):
                new_gm.encounter_to_out_var_failure = True

            new_gm_res = p(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

            # See the application of this function in exir/program/_program.py
            # for more details on why this step is necessary.
            if isinstance(p, SpecPropPass):
                p.update_placeholder_tensor_specs(program, new_gm)

            _copy_module(program.graph_module, new_gm)
        else:
            # Program-level callable: it owns the transformation end to end.
            program = p(program)

    return program
49
81
@final
50
82
class VulkanBackend (BackendDetails ):
51
83
@classmethod
@@ -57,35 +89,44 @@ def preprocess( # noqa: C901
57
89
) -> PreprocessResult :
58
90
program = unsafe_remove_auto_functionalized_pass (program )
59
91
60
- passes = [
61
- RemoveCloneOpsTransform (),
62
- AddmmToLinearTransform (),
63
- FuseDequantLinearPass (),
64
- FuseViewCopyTransform (),
65
- FuseBatchNormWithConvPass (program ),
66
- FuseClampPass (),
67
- SpecPropPass (),
68
- ConstraintBasedSymShapeEvalPass (),
69
- RemoveLocalScalarDenseOpsTransform (),
70
- MemoryPlanningPass (),
71
- ]
72
-
73
- new_gm = program .graph_module
74
-
75
- for p in passes :
76
- # This is a workaround to allow the memory planning pass to work without
77
- # having to first apply ToOutVarPass(). See the `greedy()` function in
78
- # `exir.memory_planning`; if this attribute isn't set, assertions in
79
- # `collect_spec_from_nodes()` will fail.
80
- if isinstance (p , MemoryPlanningPass ):
81
- new_gm .encounter_to_out_var_failure = True
82
- new_gm_res = p (new_gm )
83
- assert new_gm_res is not None
84
- new_gm = new_gm_res .graph_module
92
+ # First, apply passes that fuse/remove operators to consolidate the graph
93
+ # structure but still preserve an "ATen-compliant" graph structure (i.e. all
94
+ # arguments to ATen operators must match the ATen function schema).
95
+ program = apply_passes (
96
+ program ,
97
+ [
98
+ RemoveCloneOpsTransform (),
99
+ AddmmToLinearTransform (),
100
+ FuseDequantLinearPass (),
101
+ FuseViewCopyTransform (),
102
+ FuseBatchNormWithConvPass (program ),
103
+ FuseClampPass (),
104
+ ],
105
+ )
85
106
86
- _copy_module (program .graph_module , new_gm )
107
+ # Next annotate tensor nodes with TensorSpec structs which is needed for dynamic
108
+ # shapes and memory planning. Until this point, the graph must be ATen compliant
109
+ # because SpecPropPass will be calling the underlying ATen operators during its
110
+ # execution.
111
+ program = apply_passes (program , [SpecPropPass ()])
112
+
113
+ # Apply graph transforms which either require `TensorSpec`s to have been created
114
+ # or would create an non ATen compliant graph structure.
115
+ program = apply_passes (
116
+ program ,
117
+ [
118
+ # Since this pass may replace a scalar argument with a tensor argument,
119
+ # this pass may result in a non ATen compliant graph structure.
120
+ RemoveLocalScalarDenseOpsTransform (),
121
+ insert_prepack_nodes ,
122
+ ],
123
+ )
87
124
88
- program = insert_prepack_nodes (program )
125
+ # Finally, apply dynamic shape passes and memory planning pass. These passes
126
+ # must be applied only when the graph structure is finalized.
127
+ program = apply_passes (
128
+ program , [ConstraintBasedSymShapeEvalPass (), MemoryPlanningPass ()]
129
+ )
89
130
90
131
graph_builder = VkGraphBuilder (
91
132
program , DelegateMappingBuilder (generated_identifiers = True )
0 commit comments