Commit 1b0c5e4

jackzhxng authored and facebook-github-bot committed
Use to_edge_lower_and_transform for XNNPack (#8624)
Summary: Use `to_edge_transform_and_lower` in `export_llama` for XNNPack. As part of these changes, you can no longer specify multiple backends in the `export_llama` args, although I'm not sure whether anything relies on that at the moment. Closes #8621

Performance regression benchmarking for xnnpack (on Android) vs. the past 3 days:

[Screenshot 2025-02-24 at 11:39:52 AM: https://github.com/user-attachments/assets/1640cf2c-a579-491f-8940-7ccfbe464903]

These benchmark numbers normally fluctuate a bit across runs, and the differences here are within the usual fluctuation range.

Test Plan: See if CI passes

Differential Revision: D70124742

Pulled By: jackzhxng
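For readers who haven't used the newer API, here is a minimal, self-contained sketch of the export → `to_edge_transform_and_lower` → `to_executorch` flow that this commit switches the XNNPack path to. The toy `TinyLinear` module and the output filename are invented for illustration; the call shapes follow the public ExecuTorch APIs and mirror the usage in the diff below.

```python
# Sketch of the single-step lowering flow adopted here (illustrative toy model).
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(self.linear(x))


example_inputs = (torch.randn(1, 16),)
exported = torch.export.export(TinyLinear(), example_inputs)

# Partitioning/delegation happens as part of edge lowering, replacing the
# separate export_to_edge() + to_backend() steps used previously.
edge = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
executorch_program = edge.to_executorch()

with open("tiny_linear_xnnpack.pte", "wb") as f:
    f.write(executorch_program.buffer)
```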
1 parent 7103bb3 commit 1b0c5e4

File tree

2 files changed: +110 -47 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 91 additions & 41 deletions
@@ -674,47 +674,62 @@ def _validate_args(args):
         )
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
 
-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"
 
+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -729,7 +744,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -758,13 +772,11 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -790,19 +802,15 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
                     atten.head_dim,
                 )
             )
-            # pyre-ignore
            tag_quant_io(
                 builder_exported_to_edge.edge_manager.exported_program().graph_module,
-                partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+                partial(get_custom_quant_ios_dtype, cache_shape),
             )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -816,7 +824,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -838,11 +845,55 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -864,7 +915,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
-
     return builder
 
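A side effect of the new `_export_llama` dispatch above (and the reason multiple backends can no longer be combined): PT2E linear quantization now forces `args.xnnpack = True` and routes the run to `_to_edge_and_lower_llama_xnnpack`, while every other backend goes through `_to_edge_and_lower_llama`. A stand-alone sketch of that routing decision follows; `Args` and `PT2EQuantParams` are hypothetical stand-ins for the real argparse namespace and quantizer params, included only to make the control flow concrete.

```python
# Hypothetical stand-ins to illustrate the XNNPack routing decision in _export_llama.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PT2EQuantParams:
    # Stand-in for the real PT2E quant params; only quantize_linear matters here.
    quantize_linear: Optional[object] = None


@dataclass
class Args:
    # Stand-in for the export_llama argparse namespace.
    xnnpack: bool = False
    vulkan: bool = False


def routes_to_xnnpack(args: Args, pt2e: Optional[PT2EQuantParams]) -> bool:
    # Mirrors the dispatch above: PT2E dynamic linear quantization implies the
    # XNNPack lowering path even if --xnnpack was not requested explicitly.
    if pt2e is not None and pt2e.quantize_linear is not None:
        args.xnnpack = True
    return args.xnnpack


# Even with vulkan requested, linear PT2E quantization routes to the XNNPack path.
print(routes_to_xnnpack(Args(vulkan=True), PT2EQuantParams(quantize_linear=object())))  # True
print(routes_to_xnnpack(Args(), None))  # False -> falls back to _to_edge_and_lower_llama
```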
extension/llm/export/builder.py

Lines changed: 19 additions & 6 deletions
@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -39,7 +39,7 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training
+from torch.export import export_for_training, ExportedProgram
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        # graph module returned from export()
-        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
+        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             # `Module`.
+            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -330,7 +330,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
             assert (
                 self.pre_autograd_graph_module is not None
             ), "Please run export() first"
-            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)  # pyre-ignore[6]: In call `prepare_pt2e`, for 1st positional argument, expected `GraphModule` but got `Module`.
             logging.info(
                 f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
             )
@@ -430,6 +430,19 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
 
         return self
 
+    def to_edge_transform_and_lower(
+        self, partitioners: Optional[List[Partitioner]]
+    ) -> "LLMEdgeManager":
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend lowering...")
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(
         self, passes: Optional[List[ExportPass]] = None
     ) -> "LLMEdgeManager":
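For contrast with the new `to_edge_transform_and_lower` method above, here is a sketch of the older two-step flow that `_to_edge_and_lower_llama` still uses for the non-XNNPack backends, expressed with the public exir APIs (edge conversion first, then a separate `to_backend()` call). The toy module and the partitioner choice are illustrative assumptions, not the exact code paths used by `export_llama`.

```python
# Two-step edge conversion + delegation, for comparison with the one-step path above.
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


exported = torch.export.export(TinyLinear(), (torch.randn(1, 16),))

# Step 1: convert the exported program to edge dialect.
edge = to_edge(exported, compile_config=EdgeCompileConfig(_check_ir_validity=False))

# Step 2: delegate supported subgraphs to a backend after edge conversion.
edge = edge.to_backend(XnnpackPartitioner())

executorch_program = edge.to_executorch()
```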