
Use to_edge_lower_and_transform for XNNPack #8624

Merged 1 commit on Feb 25, 2025
132 changes: 91 additions & 41 deletions examples/models/llama/export_llama_lib.py
@@ -676,47 +676,62 @@ def _validate_args(args):
)


def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()

builder_exported.run_canonical_optimizations()

if args.export_only:
exit()

builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()

modelname = builder_exported_to_edge.modelname

# to_backend
def _to_edge_and_lower_llama_xnnpack(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
) -> LLMEdgeManager: # noqa: C901
partitioners = []

# Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
if (
pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
) or (args.xnnpack):
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
)
partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))

# force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
args.xnnpack = True
modelname = f"xnnpack_dq_{modelname}"
modelname = f"xnnpack_dq_{modelname}"

if args.xnnpack_extended_ops:
assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)
modelname = f"xnnpack_{modelname}"

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
if args.generate_etrecord:
raise NotImplementedError(
"export_llama does not support XNNPack and generating ETRecord at the moment."
)

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)
if args.verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)

return builder.to_executorch(passes=additional_passes)


def _to_edge_and_lower_llama( # noqa: C901
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
):
builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()

# to_backend
partitioners = []
if args.vulkan:
partitioners.append(
get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
modelname = f"vulkan_{modelname}"

# Need to remove asserts from the graph to prevent graph breaks
# pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
_transform(builder_exported_to_edge.edge_manager.exported_program())

if args.num_sharding > 0:
model_sharding.split_graph(
builder_exported_to_edge.edge_manager.exported_program(),
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
builder_exported_to_edge.metadata["get_n_layers"],
shares=args.num_sharding,
)
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
atten.head_dim,
)
)
# pyre-ignore
tag_quant_io(
builder_exported_to_edge.edge_manager.exported_program().graph_module,
partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore
partial(get_custom_quant_ios_dtype, cache_shape),
)

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch(passes=additional_passes)

return builder


def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()
builder_exported.run_canonical_optimizations()
modelname = builder_exported.modelname

if args.export_only:
exit()

if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
# Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
args.xnnpack = True

if args.xnnpack:
builder = _to_edge_and_lower_llama_xnnpack(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
)
else:
builder = _to_edge_and_lower_llama(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
output_file = f"{builder.output_dir}/{modelname}.pte"

builder.save_to_pte(output_file)

return builder


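For readers skimming the diff, here is a minimal, self-contained sketch of the control flow that `_export_llama` ends up with after this change. It is an illustration only: `StubArgs`, `StubBuilder`, and `lower` are hypothetical stand-ins for the parsed CLI args, `LLMEdgeManager`, and the real dispatch, and the partitioner strings stand in for `get_xnnpack_partitioner(...)` results.

```python
# Hypothetical, trimmed-down sketch of the new dispatch: the XNNPack path
# quantizes and lowers in a single to_edge_transform_and_lower() call, while
# every other backend keeps the existing export_to_edge() + to_backend() flow.
# The Stub* classes are illustrative stand-ins, not ExecuTorch APIs.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StubArgs:
    xnnpack: bool = False
    xnnpack_extended_ops: bool = False


@dataclass
class StubBuilder:
    calls: List[str] = field(default_factory=list)

    def _record(self, name: str) -> "StubBuilder":
        self.calls.append(name)
        return self

    def pt2e_quantize(self, quantizers: Optional[list]) -> "StubBuilder":
        return self._record("pt2e_quantize")

    def to_edge_transform_and_lower(self, partitioners: list) -> "StubBuilder":
        return self._record("to_edge_transform_and_lower")

    def export_to_edge(self) -> "StubBuilder":
        return self._record("export_to_edge")

    def to_backend(self, partitioners: list) -> "StubBuilder":
        return self._record("to_backend")

    def to_executorch(self, passes: Optional[list] = None) -> "StubBuilder":
        return self._record("to_executorch")


def lower(builder: StubBuilder, args: StubArgs, quantizers=None) -> StubBuilder:
    if args.xnnpack:
        # New path introduced by this PR: delegate during edge lowering.
        partitioners = ["xnnpack (dynamic quant only)"]
        if args.xnnpack_extended_ops:
            partitioners.append("xnnpack (extended ops)")
        return (
            builder.pt2e_quantize(quantizers)
            .to_edge_transform_and_lower(partitioners)
            .to_executorch()
        )
    # Pre-existing path, kept for vulkan / mps / coreml / qnn.
    return (
        builder.pt2e_quantize(quantizers)
        .export_to_edge()
        .to_backend([])
        .to_executorch()
    )


if __name__ == "__main__":
    print(lower(StubBuilder(), StubArgs(xnnpack=True)).calls)
    # -> ['pt2e_quantize', 'to_edge_transform_and_lower', 'to_executorch']
```

The design point the sketch highlights is that the XNNPack path no longer routes through `export_to_edge()` and `to_backend()`; it quantizes and delegates inside one `to_edge_transform_and_lower()` call, which is also why the diff raises `NotImplementedError` for ETRecord generation on that path for now.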
1 change: 0 additions & 1 deletion examples/models/llava/export_llava.py
@@ -67,7 +67,6 @@ def export(self) -> "LlavaEdgeManager":
dynamic_shapes=dynamic_shape,
strict=False,
)
# pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
self.pre_autograd_graph_module = self.export_program.module()
return self

28 changes: 22 additions & 6 deletions extension/llm/export/builder.py
@@ -21,7 +21,7 @@
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
@@ -39,7 +39,7 @@
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.export import export_for_training
from torch.export import export_for_training, ExportedProgram
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
):
self.model = model
# graph module returned from export()
self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
self.pre_autograd_exported_program: Optional[ExportedProgram] = None
self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
self.modelname = modelname
self.max_seq_len = max_seq_len
self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
kwargs=self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
)
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
# `Module`.
self.pre_autograd_exported_program = exported_module
self.pre_autograd_graph_module = exported_module.module()
if hasattr(self.args, "export_only") and self.args.export_only:
torch.export.save(exported_module, self.args.output_name)
@@ -330,7 +330,10 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
m = prepare_pt2e(
self.pre_autograd_graph_module, # pyre-ignore[6]
composed_quantizer,
)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
Expand Down Expand Up @@ -430,6 +433,19 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

return self

def to_edge_transform_and_lower(
self, partitioners: Optional[List[Partitioner]]
) -> "LLMEdgeManager":
if partitioners is None:
logging.info("No partitioner provided, skipping backend lowering...")
edge_config = self._get_edge_config()
self.edge_manager = to_edge_transform_and_lower(
self.pre_autograd_exported_program,
partitioner=partitioners,
compile_config=edge_config,
)
return self

def to_executorch(
self, passes: Optional[List[ExportPass]] = None
) -> "LLMEdgeManager":
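Finally, a hedged usage sketch of the new `LLMEdgeManager.to_edge_transform_and_lower()` method added in `builder.py`, based only on the call sites visible in this diff. The import path for `LLMEdgeManager` and the assumption that `pt2e_quantize()` passes through when no quantizers are supplied are mine, not guaranteed by this PR.

```python
# Hedged sketch of driving the new builder API end to end; the LLMEdgeManager
# import path is assumed from the repo layout (extension/llm/export/builder.py).
from typing import List, Optional

from executorch.exir.backend.partitioner import Partitioner
from executorch.extension.llm.export.builder import LLMEdgeManager  # path assumed


def quantize_and_lower(
    builder: LLMEdgeManager,
    partitioners: Optional[List[Partitioner]],
    quantizers: Optional[list] = None,
) -> LLMEdgeManager:
    # export() populates pre_autograd_exported_program, which the new
    # to_edge_transform_and_lower() consumes directly instead of going
    # through export_to_edge() + to_backend().
    builder = builder.export()
    builder = builder.pt2e_quantize(quantizers)  # assumed to pass through when quantizers is falsy
    builder = builder.to_edge_transform_and_lower(partitioners)
    return builder.to_executorch()
```

Note that, per the diff, passing `partitioners=None` only logs a notice before lowering proceeds, so callers that want an undelegated edge program may still prefer the existing `export_to_edge()` path.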