|
6 | 6 | # Partitioner for the NXP Neutron NPU
|
7 | 7 |
|
8 | 8 | import logging
|
| 9 | +import operator |
9 | 10 | from typing import final, List
|
10 | 11 |
|
11 | 12 | import torch
|
|
43 | 44 | # exir_ops.edge.aten.sub.Scalar,
|
44 | 45 | # exir_ops.edge.aten.tanh.default,
|
45 | 46 | # operator.getitem,
|
46 |
| - |
47 |
| - # QDQ ops |
48 |
| - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
49 |
| - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
50 | 47 | ]
|
51 | 48 |
|
class NeutronSupportedOperators(OperatorSupportBase):
    """Operator-support predicate used when partitioning for Neutron."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """
        Decide whether ``node`` may be placed in a Neutron partition.

        A node is accepted when it is a ``call_function`` whose target is
        listed in ``NeutronSupportedOperatorsList``, or when it carries a
        ``"cluster"`` entry in ``node.meta`` (i.e. it was tagged as part of
        a QDQ cluster).

        Args:
            submodules: Unused here; part of the ``is_node_supported``
                signature.
            node: The graph node to examine.

        Returns:
            True if the node is supported, False otherwise.
        """
        is_listed_op = (
            node.op == "call_function"
            and node.target in NeutronSupportedOperatorsList
        )
        return is_listed_op or "cluster" in node.meta
56 | 58 |
|
57 | 59 | @final
|
58 | 60 | class NeutronPartitioner(Partitioner):
|
def __init__(self, compile_spec: List[CompileSpec]) -> None:
    """
    Store the delegation spec used for every Neutron partition.

    Args:
        compile_spec: Compile specs forwarded unchanged to the
            ``DelegationSpec``.
    """
    # The delegation spec is identified by the backend class name.
    backend_id = NeutronBackend.__name__
    self.delegation_spec = DelegationSpec(backend_id, compile_spec)
|
61 | 63 |
|
def is_quant_node(self, node: torch.fx.node.Node) -> bool:
    """
    Return True when ``node.target`` is one of the ``quantized_decomposed``
    quantize ops (per-channel, per-tensor default, per-tensor tensor).
    """
    quantize_targets = {
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
    }
    return node.target in quantize_targets
| 70 | + |
def is_dequant_node(self, node: torch.fx.node.Node) -> bool:
    """
    Return True when ``node.target`` is one of the ``quantized_decomposed``
    dequantize ops (per-channel, per-tensor default, per-tensor tensor).
    """
    dequantize_targets = {
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
    }
    return node.target in dequantize_targets
| 77 | + |
def tag_clusters(self, nodes):
    """
    Tag quantize/dequantize (QDQ) clusters in ``nodes``.

    A cluster is formed around every ``call_function`` node that has at
    least one input produced by a dequantize node and at least one output
    consumed by a quantize node. The node itself, those dequantize inputs
    and those quantize outputs get ``meta["cluster"]`` set to
    ``"<node.name>_cluster"``. The tag is what later lets the partitioner
    treat the whole group as supported.

    If a quantize/dequantize node is related to several tagged nodes, the
    tag written last wins, because the ``meta`` entry is simply overwritten.

    Args:
        nodes: Iterable of ``torch.fx`` nodes to inspect.

    Returns:
        None. Node ``meta`` dictionaries are modified in place.
    """

    def get_dequant_inputs(node):
        """
        Return every dequantize node that feeds ``node``.

        One dequantized input is enough for the result to be non-empty,
        even if the operator has further inputs that are not dequantized.
        This is deliberate: it handles the unexpected behavior of the
        NeutronQuantizer with the bias tensor (EIEX-66).

        ``all_input_nodes`` is used instead of walking ``node.args``
        directly so that producers passed inside a list/tuple argument
        (e.g. the tensor list of a concatenation) or via kwargs are
        found as well.
        """
        return [
            producer
            for producer in node.all_input_nodes
            if self.is_dequant_node(producer)
        ]

    def get_quant_outputs(node):
        """
        Return every quantize node that consumes an output of ``node``.

        Users that are ``operator.getitem`` calls are looked through:
        their own users are checked for quantize nodes instead.
        """
        quant_outputs = []
        for user in node.users:
            if user.op == "call_function" and user.target == operator.getitem:
                quant_outputs.extend(
                    consumer
                    for consumer in user.users
                    if self.is_quant_node(consumer)
                )
            elif self.is_quant_node(user):
                quant_outputs.append(user)
        return quant_outputs

    for node in nodes:
        if node.op != "call_function":
            continue

        dequant_inputs = get_dequant_inputs(node)
        quant_outputs = get_quant_outputs(node)
        # Both sides are required: dequantize in front, quantize behind.
        if not (dequant_inputs and quant_outputs):
            continue

        cluster_name = f"{node.name}_cluster"
        logging.info(f"Tagging node {node} as {cluster_name}")
        for member in (node, *dequant_inputs, *quant_outputs):
            member.meta["cluster"] = cluster_name
| 136 | + |
62 | 137 | def partition(self, exported_program: ExportedProgram) -> PartitionResult:
|
63 | 138 | # Run the CapabilityBasedPartitioner to return the largest possible
|
64 | 139 | # subgraphs containing the nodes with the tags
|
65 | 140 | logging.info("NeutronPartitioner::partition")
|
66 | 141 | partition_tags = {}
|
67 | 142 |
|
| 143 | + graph_module = exported_program.graph_module |
| 144 | + nodes = list(graph_module.graph.nodes) |
| 145 | + |
| 146 | + self.tag_clusters(nodes) |
| 147 | + |
| 148 | + graph_module.recompile() |
| 149 | + |
68 | 150 | capability_partitioner = CapabilityBasedPartitioner(
|
69 | 151 | exported_program.graph_module,
|
70 | 152 | NeutronSupportedOperators(),
|
|
0 commit comments