|
16 | 16 | QUANT_ANNOTATION_KEY,
|
17 | 17 | )
|
18 | 18 | from executorch.exir.dialects._ops import ops as exir_ops
|
| 19 | +from torch.ao.quantization.observer import MinMaxObserver |
19 | 20 | from torch.ao.quantization.quantizer import (
|
20 | 21 | QuantizationAnnotation,
|
21 | 22 | SharedQuantizationSpec,
|
22 | 23 | )
|
23 | 24 | from torch.fx import Node
|
24 | 25 |
|
25 | 26 |
|
| 27 | +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: |
| 28 | + """ |
| 29 | + This function is specific for matmul op 16a8w. |
| 30 | + """ |
| 31 | + |
| 32 | + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): |
| 33 | + input_qspec_map = {} |
| 34 | + input_act = node.args[0] |
| 35 | + input_spec = quantization_config.input_activation |
| 36 | + input_qspec_map[input_act] = input_spec |
| 37 | + |
| 38 | + input_act1 = node.args[1] |
| 39 | + input_spec1 = quantization_config.weight |
| 40 | + input_qspec_map[input_act1] = input_spec1 |
| 41 | + |
| 42 | + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| 43 | + input_qspec_map=input_qspec_map, |
| 44 | + output_qspec=quantization_config.output_activation, |
| 45 | + _annotated=True, |
| 46 | + ) |
| 47 | + |
| 48 | + def annotate_cat(node: Node, quantization_config: QuantizationConfig): |
| 49 | + input_nodes = node.args[0] |
| 50 | + |
| 51 | + first_input_node = input_nodes[0] |
| 52 | + input_qspec_map = {} |
| 53 | + input_qspec_map[first_input_node] = quantization_config.input_activation |
| 54 | + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( |
| 55 | + (first_input_node, node) |
| 56 | + ) |
| 57 | + |
| 58 | + for input_node in input_nodes[1:]: |
| 59 | + if input_node not in input_qspec_map: |
| 60 | + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec |
| 61 | + |
| 62 | + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| 63 | + input_qspec_map=input_qspec_map, |
| 64 | + output_qspec=share_qparams_with_input_act0_qspec, |
| 65 | + _annotated=True, |
| 66 | + ) |
| 67 | + |
| 68 | + def annotate_single_in_single_out( |
| 69 | + node: Node, quantization_config: QuantizationConfig |
| 70 | + ) -> None: |
| 71 | + |
| 72 | + input_qspec_map = {} |
| 73 | + input_act = node.args[0] |
| 74 | + input_qspec_map[input_act] = quantization_config.input_activation |
| 75 | + |
| 76 | + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( |
| 77 | + input_qspec_map=input_qspec_map, |
| 78 | + output_qspec=quantization_config.output_activation, |
| 79 | + _annotated=True, |
| 80 | + ) |
| 81 | + |
| 82 | + def annotate_matmul_input1(node: Node): |
| 83 | + quantization_config_8a8w = get_default_8bit_qnn_ptq_config( |
| 84 | + act_symmetric=True, act_observer=MinMaxObserver |
| 85 | + ) |
| 86 | + while isinstance(node, Node) and node.op == "call_function": |
| 87 | + if node.target in [ |
| 88 | + torch.ops.aten.permute.default, |
| 89 | + torch.ops.aten.transpose.int, |
| 90 | + ]: |
| 91 | + annotate_single_in_single_out(node, quantization_config_8a8w) |
| 92 | + node = node.args[0] |
| 93 | + elif node.target == torch.ops.aten.cat.default: |
| 94 | + annotate_cat(node, quantization_config_8a8w) |
| 95 | + node = node.args[0][0] |
| 96 | + else: |
| 97 | + node = node.args[0] |
| 98 | + |
| 99 | + quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver) |
| 100 | + |
| 101 | + for node in gm.graph.nodes: |
| 102 | + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: |
| 103 | + annotate_matmul(node, quantization_config_16a8w) |
| 104 | + annotate_matmul_input1(node.args[1]) |
| 105 | + |
| 106 | + |
26 | 107 | def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
|
27 | 108 | """
|
28 | 109 | This function is specific for llama matmul op 16a8w.
|
|
0 commit comments