@@ -674,47 +674,62 @@ def _validate_args(args):
         )


-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))

-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"

     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"

+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -729,7 +744,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"vulkan_{modelname}"

     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

     if args.mps:
@@ -758,13 +772,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
        _transform(builder_exported_to_edge.edge_manager.exported_program())

         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -790,19 +802,15 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                 atten.head_dim,
             )
         )
-        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+            partial(get_custom_quant_ios_dtype, cache_shape),
         )

     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -816,7 +824,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(
@@ -838,11 +845,55 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(passes=additional_passes)

+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -864,7 +915,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"

     builder.save_to_pte(output_file)
-
     return builder


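Taken together, the commit splits the old monolithic _export_llama into two lowering paths plus a thin dispatcher. Below is a minimal sketch of the resulting control flow, assuming the LLMEdgeManager builder API visible in the diff (pt2e_quantize, export_to_edge, to_edge_transform_and_lower, to_executorch) plus the to_backend step that the "# to_backend" comment marks; the helper name `lower` is illustrative, not part of the library.

def lower(builder_exported, quantizers, partitioners, additional_passes, args):
    # Sketch only: condenses _to_edge_and_lower_llama_xnnpack and
    # _to_edge_and_lower_llama into one illustrative function.
    if args.xnnpack:
        # XNNPACK path: quantize, then convert to edge and delegate in one call.
        builder = builder_exported.pt2e_quantize(
            quantizers
        ).to_edge_transform_and_lower(partitioners)
    else:
        # Generic path: convert to edge first, then hand the graph to backends.
        builder = builder_exported.pt2e_quantize(quantizers).export_to_edge()
        builder = builder.to_backend(partitioners)  # assumed step, per "# to_backend"
    return builder.to_executorch(passes=additional_passes)

The notable behavior change: args.xnnpack is now forced on up front whenever pt2e linear quantization is requested, so the choice between the one-shot to_edge_transform_and_lower() path and the two-step edge-then-backend path is made exactly once, in _export_llama.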