@@ -676,62 +676,47 @@ def _validate_args(args):
676
676
)
677
677
678
678
679
- def _to_edge_and_lower_llama_xnnpack (
680
- builder_exported ,
681
- modelname ,
682
- additional_passes ,
683
- pt2e_quant_params ,
684
- quantizers ,
685
- quant_dtype ,
686
- args ,
687
- ) -> LLMEdgeManager : # noqa: C901
688
- partitioners = []
689
-
690
- # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
691
- partitioners .append (get_xnnpack_partitioner (dynamic_quant_only_partitioner = True ))
692
-
693
- modelname = f"xnnpack_dq_{ modelname } "
694
-
695
- if args .xnnpack_extended_ops :
696
- partitioners .append (
697
- get_xnnpack_partitioner (dynamic_quant_only_partitioner = False )
698
- )
699
- modelname = f"xnnpack_{ modelname } "
700
-
701
- logging .info ("Lowering model using following partitioner(s): " )
702
- for partitioner in partitioners :
703
- logging .info (f"--> { partitioner .__class__ .__name__ } " )
679
+ def _export_llama (args ) -> LLMEdgeManager : # noqa: C901
680
+ _validate_args (args )
704
681
705
- # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
706
- if args .generate_etrecord :
707
- raise NotImplementedError (
708
- "export_llama does not support XNNPack and generating ETRecord at the moment."
709
- )
682
+ pt2e_quant_params , quantizers , quant_dtype = get_quantizer_and_quant_params (args )
710
683
711
- builder = builder_exported .pt2e_quantize (quantizers ).to_edge_transform_and_lower (
712
- partitioners
713
- )
714
- if args .verbose :
715
- print_delegation_info (builder .edge_manager .exported_program ().graph_module )
684
+ # export_to_edge
685
+ builder_exported = _prepare_for_llama_export (args ).export ()
716
686
717
- return builder . to_executorch ( passes = additional_passes )
687
+ builder_exported . run_canonical_optimizations ( )
718
688
689
+ if args .export_only :
690
+ exit ()
719
691
720
- def _to_edge_and_lower_llama ( # noqa: C901
721
- builder_exported ,
722
- modelname ,
723
- additional_passes ,
724
- pt2e_quant_params ,
725
- quantizers ,
726
- quant_dtype ,
727
- args ,
728
- ):
729
692
builder_exported_to_edge = builder_exported .pt2e_quantize (
730
693
quantizers
731
694
).export_to_edge ()
732
695
696
+ modelname = builder_exported_to_edge .modelname
697
+
733
698
# to_backend
734
699
partitioners = []
700
+
701
+ # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
702
+ if (
703
+ pt2e_quant_params is not None and pt2e_quant_params .quantize_linear is not None
704
+ ) or (args .xnnpack ):
705
+ partitioners .append (
706
+ get_xnnpack_partitioner (dynamic_quant_only_partitioner = True )
707
+ )
708
+
709
+ # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
710
+ args .xnnpack = True
711
+ modelname = f"xnnpack_dq_{ modelname } "
712
+
713
+ if args .xnnpack_extended_ops :
714
+ assert args .xnnpack , "xnnpack_extended_ops requires xnnpack to be enabled"
715
+ partitioners .append (
716
+ get_xnnpack_partitioner (dynamic_quant_only_partitioner = False )
717
+ )
718
+ modelname = f"xnnpack_{ modelname } "
719
+
735
720
if args .vulkan :
736
721
partitioners .append (
737
722
get_vulkan_partitioner (
@@ -746,6 +731,7 @@ def _to_edge_and_lower_llama( # noqa: C901
746
731
modelname = f"vulkan_{ modelname } "
747
732
748
733
# Need to remove asserts from the graph to prevent graph breaks
734
+ # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
749
735
remove_asserts (builder_exported_to_edge .edge_manager .exported_program ())
750
736
751
737
if args .mps :
@@ -774,11 +760,13 @@ def _to_edge_and_lower_llama( # noqa: C901
774
760
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
775
761
from executorch .backends .qualcomm .utils .utils import _transform , tag_quant_io
776
762
763
+ # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
777
764
_transform (builder_exported_to_edge .edge_manager .exported_program ())
778
765
779
766
if args .num_sharding > 0 :
780
767
model_sharding .split_graph (
781
768
builder_exported_to_edge .edge_manager .exported_program (),
769
+ # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
782
770
builder_exported_to_edge .metadata ["get_n_layers" ],
783
771
shares = args .num_sharding ,
784
772
)
@@ -804,15 +792,19 @@ def _to_edge_and_lower_llama( # noqa: C901
804
792
atten .head_dim ,
805
793
)
806
794
)
795
+ # pyre-ignore
807
796
tag_quant_io (
808
797
builder_exported_to_edge .edge_manager .exported_program ().graph_module ,
809
- partial (get_custom_quant_ios_dtype , cache_shape ),
798
+ partial (get_custom_quant_ios_dtype , cache_shape ), # pyre-ignore
810
799
)
811
800
812
801
logging .info ("Lowering model using following partitioner(s): " )
813
802
for partitioner in partitioners :
814
803
logging .info (f"--> { partitioner .__class__ .__name__ } " )
815
804
805
+ additional_passes = []
806
+ if args .model in TORCHTUNE_DEFINED_MODELS :
807
+ additional_passes = [InitializedMutableBufferPass (["kv_cache_pos" ])]
816
808
if args .generate_etrecord :
817
809
if not builder_exported_to_edge .edge_manager :
818
810
raise ValueError ("Unable to generate etrecord due to missing edge manager." )
@@ -826,6 +818,7 @@ def _to_edge_and_lower_llama( # noqa: C901
826
818
if args .num_sharding > 0 and args .qnn :
827
819
from executorch .backends .qualcomm .utils .utils import canonicalize_program
828
820
821
+ # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
829
822
canonicalize_program (builder .edge_manager .exported_program ())
830
823
831
824
builder = builder .to_executorch (
@@ -847,55 +840,11 @@ def _to_edge_and_lower_llama( # noqa: C901
847
840
if args .num_sharding > 0 and args .qnn :
848
841
from executorch .backends .qualcomm .utils .utils import canonicalize_program
849
842
843
+ # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
850
844
canonicalize_program (builder .edge_manager .exported_program ())
851
845
852
846
builder = builder .to_executorch (passes = additional_passes )
853
847
854
- return builder
855
-
856
-
857
- def _export_llama (args ) -> LLMEdgeManager : # noqa: C901
858
- _validate_args (args )
859
-
860
- pt2e_quant_params , quantizers , quant_dtype = get_quantizer_and_quant_params (args )
861
-
862
- additional_passes = []
863
- if args .model in TORCHTUNE_DEFINED_MODELS :
864
- additional_passes = [InitializedMutableBufferPass (["kv_cache_pos" ])]
865
-
866
- # export_to_edge
867
- builder_exported = _prepare_for_llama_export (args ).export ()
868
- builder_exported .run_canonical_optimizations ()
869
- modelname = builder_exported .modelname
870
-
871
- if args .export_only :
872
- exit ()
873
-
874
- if pt2e_quant_params is not None and pt2e_quant_params .quantize_linear is not None :
875
- # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
876
- args .xnnpack = True
877
-
878
- if args .xnnpack :
879
- builder = _to_edge_and_lower_llama_xnnpack (
880
- builder_exported ,
881
- modelname ,
882
- additional_passes ,
883
- pt2e_quant_params ,
884
- quantizers ,
885
- quant_dtype ,
886
- args ,
887
- )
888
- else :
889
- builder = _to_edge_and_lower_llama (
890
- builder_exported ,
891
- modelname ,
892
- additional_passes ,
893
- pt2e_quant_params ,
894
- quantizers ,
895
- quant_dtype ,
896
- args ,
897
- )
898
-
899
848
if args .profile_memory :
900
849
generate_memory_trace (builder .export_program , "memory_profile.json" )
901
850
@@ -917,6 +866,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
917
866
output_file = f"{ builder .output_dir } /{ modelname } .pte"
918
867
919
868
builder .save_to_pte (output_file )
869
+
920
870
return builder
921
871
922
872
0 commit comments