Commit 8a95288

Revert "Use to_edge_lower_and_transform for XNNPack (#8624)"
This reverts commit b5344c1.
1 parent adb897c commit 8a95288
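For context: the reverted change (#8624) had routed the XNNPACK export path through LLMEdgeManager.to_edge_transform_and_lower(), while this revert restores the flow that quantizes, converts to the edge dialect with export_to_edge(), and then delegates through to_backend(). The sketch below contrasts the two call chains as they appear in this diff; it is illustrative only, not part of the commit, and assumes an already-exported LLMEdgeManager (builder_exported), a list of pt2e quantizers, and a prebuilt partitioner list.

# Illustrative sketch only (not part of this commit). Assumes an exported
# LLMEdgeManager `builder_exported`, pt2e `quantizers`, and `partitioners`.

def lower_via_to_edge_transform_and_lower(builder_exported, quantizers, partitioners):
    # Path removed by this revert: quantize, then convert to edge and delegate
    # in a single to_edge_transform_and_lower() call on the manager.
    return builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
        partitioners
    )


def lower_via_export_to_edge(builder_exported, quantizers, partitioners):
    # Path restored by this revert: quantize, convert to the edge dialect, then
    # hand the partitioners to to_backend() for delegation.
    return (
        builder_exported.pt2e_quantize(quantizers)
        .export_to_edge()
        .to_backend(partitioners)
    )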

3 files changed (+48, -113 lines)
examples/models/llama/export_llama_lib.py

Lines changed: 41 additions & 91 deletions
@@ -676,62 +676,47 @@ def _validate_args(args):
             )
 
 
-def _to_edge_and_lower_llama_xnnpack(
-    builder_exported,
-    modelname,
-    additional_passes,
-    pt2e_quant_params,
-    quantizers,
-    quant_dtype,
-    args,
-) -> LLMEdgeManager:  # noqa: C901
-    partitioners = []
-
-    # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
-
-    modelname = f"xnnpack_dq_{modelname}"
-
-    if args.xnnpack_extended_ops:
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
-        modelname = f"xnnpack_{modelname}"
-
-    logging.info("Lowering model using following partitioner(s): ")
-    for partitioner in partitioners:
-        logging.info(f"--> {partitioner.__class__.__name__}")
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
 
-    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
-    if args.generate_etrecord:
-        raise NotImplementedError(
-            "export_llama does not support XNNPack and generating ETRecord at the moment."
-        )
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
-    if args.verbose:
-        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
 
-    return builder.to_executorch(passes=additional_passes)
+    builder_exported.run_canonical_optimizations()
 
+    if args.export_only:
+        exit()
 
-def _to_edge_and_lower_llama(  # noqa: C901
-    builder_exported,
-    modelname,
-    additional_passes,
-    pt2e_quant_params,
-    quantizers,
-    quant_dtype,
-    args,
-):
     builder_exported_to_edge = builder_exported.pt2e_quantize(
         quantizers
     ).export_to_edge()
 
+    modelname = builder_exported_to_edge.modelname
+
     # to_backend
     partitioners = []
+
+    # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
+    if (
+        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
+    ) or (args.xnnpack):
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
+        )
+
+        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+        modelname = f"xnnpack_dq_{modelname}"
+
+    if args.xnnpack_extended_ops:
+        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
+        partitioners.append(
+            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
+        )
+        modelname = f"xnnpack_{modelname}"
+
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -746,6 +731,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
+    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
 
     if args.mps:
@@ -774,11 +760,13 @@ def _to_edge_and_lower_llama(  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
+        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())
 
         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
+                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -804,15 +792,19 @@ def _to_edge_and_lower_llama(  # noqa: C901
                     atten.head_dim,
                 )
             )
+            # pyre-ignore
            tag_quant_io(
                 builder_exported_to_edge.edge_manager.exported_program().graph_module,
-                partial(get_custom_quant_ios_dtype, cache_shape),
+                partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
             )
 
     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -826,6 +818,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
+            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(
@@ -847,55 +840,11 @@ def _to_edge_and_lower_llama(  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program
 
+            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())
 
         builder = builder.to_executorch(passes=additional_passes)
 
-    return builder
-
-
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-    builder_exported.run_canonical_optimizations()
-    modelname = builder_exported.modelname
-
-    if args.export_only:
-        exit()
-
-    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-
-    if args.xnnpack:
-        builder = _to_edge_and_lower_llama_xnnpack(
-            builder_exported,
-            modelname,
-            additional_passes,
-            pt2e_quant_params,
-            quantizers,
-            quant_dtype,
-            args,
-        )
-    else:
-        builder = _to_edge_and_lower_llama(
-            builder_exported,
-            modelname,
-            additional_passes,
-            pt2e_quant_params,
-            quantizers,
-            quant_dtype,
-            args,
-        )
-
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
@@ -917,6 +866,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"
 
     builder.save_to_pte(output_file)
+
     return builder
 
 
examples/models/llava/export_llava.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def export(self) -> "LlavaEdgeManager":
             dynamic_shapes=dynamic_shape,
             strict=False,
         )
+        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
         self.pre_autograd_graph_module = self.export_program.module()
         return self
 
extension/llm/export/builder.py

Lines changed: 6 additions & 22 deletions
@@ -21,7 +21,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import EdgeProgramManager
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -39,7 +39,7 @@
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training, ExportedProgram
+from torch.export import export_for_training
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
-        self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        # graph module returned from export()
+        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
         self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
+            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
-            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -330,10 +330,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
                 assert (
                     self.pre_autograd_graph_module is not None
                 ), "Please run export() first"
-                m = prepare_pt2e(
-                    self.pre_autograd_graph_module,  # pyre-ignore[6]
-                    composed_quantizer,
-                )
+                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
@@ -433,19 +430,6 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
-    def to_edge_transform_and_lower(
-        self, partitioners: Optional[List[Partitioner]]
-    ) -> "LLMEdgeManager":
-        if partitioners is None:
-            logging.info("No partitioner provided, skipping backend lowering...")
-        edge_config = self._get_edge_config()
-        self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
-            partitioner=partitioners,
-            compile_config=edge_config,
-        )
-        return self
-
     def to_executorch(
         self, passes: Optional[List[ExportPass]] = None
     ) -> "LLMEdgeManager":
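The removed LLMEdgeManager.to_edge_transform_and_lower() method shown above was a thin wrapper over the exir-level helper of the same name, whose import this diff also drops. Below is a minimal sketch of calling that helper directly, mirroring the removed method body; it is illustrative only, and exported_program, partitioners, and edge_config are placeholder names rather than anything defined in this commit.

# Minimal sketch mirroring the removed LLMEdgeManager.to_edge_transform_and_lower()
# body; `exported_program`, `partitioners`, and `edge_config` are placeholders.
from executorch.exir import to_edge_transform_and_lower


def lower_exported_program(exported_program, partitioners, edge_config):
    # Converts the exported program to the edge dialect and delegates it to the
    # given partitioners in one call, as the removed method did internally.
    return to_edge_transform_and_lower(
        exported_program,
        partitioner=partitioners,
        compile_config=edge_config,
    )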
