@@ -8,14 +8,16 @@
 import unittest

 from itertools import product
-from typing import Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import torch
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
 )
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackPartitioner,
+)
 from executorch.backends.xnnpack.test.tester import Quantize, Tester
 from executorch.backends.xnnpack.test.tester.tester import (
     Partition,
@@ -672,3 +674,150 @@ def _test_groupwise_dq_linear(
             .serialize()
             .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
+
+    def _test_linear_overwrite_precision(
+        self,
+        make_module: Callable[[int, int], torch.nn.Module],
+        uses_bias: bool,
+        quant_type: str,
+        quant_node_checks: Dict[str, int],
+        atol: float = 1e-03,
+    ):
+        """
+        Test overwriting the precision of the linear op: partition, lower, and
+        run the quantized linear model as an fp32 linear op.
+        In legacy_mode we verify that [add]mm is not partitioned, given that:
+        (1) we cannot assume the weights are always static (non-param);
+        (2) alternatively, when lowering [add]mm to xnn::bmm we cannot support bias;
+        (2)(a) lowering only the bias-free [add]mm, which is exposed only on the
+               legacy path, was deemed low ROI.
+        """
+
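+        # Each index i is one test case: in_sizes[i] is the batch dim;
+        # input_sizes[i] / output_sizes[i] are the linear in/out feature sizes.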
+        in_sizes = [3, 4, 4]
+        input_sizes = [4, 37, 17]
+        output_sizes = [4, 17, 37]
+
+        assert quant_type in ["per_tensor", "per_channel", "per_channel_dynamic"]
+        per_channel = "per_channel" in quant_type
+        dynamic = "dynamic" in quant_type
+        quant_config = get_symmetric_quantization_config(
+            is_per_channel=per_channel,
+            is_dynamic=dynamic,
+        )
+        # Using FP32 partitioner for this quantized graph
+        partitioner = XnnpackFloatingPointPartitioner()
+
+        def get_qnode_checks(quant_node_checks, dialect):
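+            """
+            Map each quantized_decomposed op in quant_node_checks to the node name
+            used by the given dialect ("aten" keeps dotted names, "edge" flattens
+            dots to underscores), preserving the expected counts as values.
+            """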
+            d = {}
+            assert dialect in ["aten", "edge"]
+            if dialect == "aten":
+                d = {
+                    f"torch.ops.quantized_decomposed.{op}": count
+                    for op, count in quant_node_checks.items()
+                }
+            elif dialect == "edge":
+                d = {
+                    f"executorch.exir.dialects.edge._ops.quantized_decomposed.{op}".replace(
+                        ".", "_"
+                    ): count
+                    for op, count in quant_node_checks.items()
+                }
+            assert len(d) == len(quant_node_checks)
+            return d
+
+        for i, _ in enumerate(in_sizes):
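+            # Clear cached torch._dynamo state so each shape variant is traced fresh.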
+            torch._dynamo.reset()
+            in_size = int(in_sizes[i])
+            input_size = int(input_sizes[i])
+            output_size = int(output_sizes[i])
+            input_shape = [in_size] + [input_size]
+            module = make_module(input_size, output_size).eval()
+            inputs = (torch.randn(input_shape),)
+
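+            # Edge-dialect op names expected to remain in the graph (legacy mode)
+            # or to be delegated away (new mode).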
+            addmm_op_str = (
+                "executorch_exir_dialects_edge__ops_aten_addmm_default"
+                if uses_bias
+                else "executorch_exir_dialects_edge__ops_aten_mm_default"
+            )
+            linear_op_str = "executorch_exir_dialects_edge__ops_aten_linear_default"
+
+            for legacy_mode in (True, False):
+                tester = (
+                    Tester(module, inputs)
+                    .quantize(Quantize(quantization_config=quant_config))
+                    .export()
+                    .dump_artifact()
+                    .check_count(get_qnode_checks(quant_node_checks, "aten"))
+                )
+
+                if legacy_mode:
+                    tester.to_edge()
+                    tester.partition(Partition(partitioner=partitioner))
+                    # We don't expect [add]mm to be partitioned
+                    tester.check([addmm_op_str])
+                else:
+                    tester.to_edge_transform_and_lower(
+                        ToEdgeTransformAndLower(partitioners=[partitioner])
+                    )
+                    # We do expect linear to be partitioned
+                    tester.check_not([linear_op_str])
+
+                # In legacy mode only the fp32 permute_copy gets partitioned (just a
+                # side effect); in the new mode the fp32 linear gets partitioned.
+                tester.check_count(
+                    {"torch.ops.higher_order.executorch_call_delegate": 1}
+                )
+
+                # Normally we would not see any quantized ops left in the graph, but
+                # here they must not be partitioned, so they should all remain.
+                tester.check_count(get_qnode_checks(quant_node_checks, "edge"))
+
+                # TODO: Need to figure out how to load quantized ops in pybindings.
+                # tester.to_executorch()
+                # tester.serialize()
+                # tester.run_method_and_compare_outputs(
+                #     qtol=bool(quant_config), atol=atol
+                # )
+
+    def test_qs8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_tensor",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 3,  # 1: act, 1: weight, 1: output
+                },
+            )
+
+    def test_qc8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel",
+                quant_node_checks={
+                    "quantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_tensor.default": 2,  # 1: act, 1: output
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )
+
+    def test_qd8_as_fp32(self):
+        for use_bias in (True, False):
+            self._test_linear_overwrite_precision(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                use_bias,
+                "per_channel_dynamic",
+                quant_node_checks={
+                    "quantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_tensor.tensor": 1,  # 1: act
+                    "dequantize_per_channel.default": 1,  # 1: weight
+                },
+            )