|
7 | 7 | # pyre-strict
|
8 | 8 |
|
9 | 9 | import logging
|
10 |
| -from typing import Any, Dict, final, List, Mapping, Optional |
| 10 | +from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple |
11 | 11 |
|
12 | 12 | import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
|
13 | 13 |
|
|
33 | 33 |
|
34 | 34 | from torch.fx.passes.operator_support import OperatorSupportBase
|
35 | 35 |
|
| 36 | +# pyre-ignore |
| 37 | +ops_not_to_decompose = [ |
| 38 | + torch.ops.aten.upsample_nearest2d.vec, |
| 39 | +] |
| 40 | + |
| 41 | +# pyre-ignore |
| 42 | +edge_ops_non_decomposed = [ |
| 43 | + exir_ops.edge.aten.upsample_nearest2d.vec, |
| 44 | +] |
| 45 | + |
36 | 46 |
|
37 | 47 | class VulkanSupportedOperators(OperatorSupportBase):
|
38 | 48 | _ops: OpList = enumerate_supported_ops()
|
@@ -117,6 +127,9 @@ def _is_node_supported(
|
117 | 127 | if node.target not in VulkanSupportedOperators._ops:
|
118 | 128 | return False
|
119 | 129 |
|
| 130 | + if node.op == "call_function" and node.target in edge_ops_non_decomposed: |
| 131 | + return True |
| 132 | + |
120 | 133 | features = VulkanSupportedOperators._ops[node.target]
|
121 | 134 |
|
122 | 135 | if self.require_dynamic_shapes and not features.supports_dynamic_shape:
|
@@ -150,6 +163,11 @@ def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None:
|
150 | 163 | compile_spec = parse_compile_options(self.options)
|
151 | 164 | self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
|
152 | 165 |
|
| 166 | + def ops_to_not_decompose( |
| 167 | + self, ep: ExportedProgram |
| 168 | + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: |
| 169 | + return (ops_not_to_decompose, None) |
| 170 | + |
153 | 171 | def partition(self, exported_program: ExportedProgram) -> PartitionResult:
|
154 | 172 | # Run the CapabilityBasedPartitioner to return the largest possible
|
155 | 173 | # subgraphs containing the nodes with the tags
|
|
0 commit comments