
Revert #8501 and #8624 (#8716)


Merged
merged 2 commits on Feb 26, 2025
132 changes: 41 additions & 91 deletions examples/models/llama/export_llama_lib.py
@@ -676,62 +676,47 @@ def _validate_args(args):
)


def _to_edge_and_lower_llama_xnnpack(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
) -> LLMEdgeManager: # noqa: C901
partitioners = []

# Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))

modelname = f"xnnpack_dq_{modelname}"

if args.xnnpack_extended_ops:
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)
modelname = f"xnnpack_{modelname}"

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")
def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)

# TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
if args.generate_etrecord:
raise NotImplementedError(
"export_llama does not support XNNPack and generating ETRecord at the moment."
)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)
if args.verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()

return builder.to_executorch(passes=additional_passes)
builder_exported.run_canonical_optimizations()

if args.export_only:
exit()

def _to_edge_and_lower_llama( # noqa: C901
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
):
builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()

modelname = builder_exported_to_edge.modelname

# to_backend
partitioners = []

# Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
if (
pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
) or (args.xnnpack):
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
)

# force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
args.xnnpack = True
modelname = f"xnnpack_dq_{modelname}"

if args.xnnpack_extended_ops:
assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
partitioners.append(
get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
)
modelname = f"xnnpack_{modelname}"

if args.vulkan:
partitioners.append(
get_vulkan_partitioner(
@@ -746,6 +731,7 @@ def _to_edge_and_lower_llama( # noqa: C901
modelname = f"vulkan_{modelname}"

# Need to remove asserts from the graph to prevent graph breaks
# pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

if args.mps:
@@ -774,11 +760,13 @@ def _to_edge_and_lower_llama( # noqa: C901
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
_transform(builder_exported_to_edge.edge_manager.exported_program())

if args.num_sharding > 0:
model_sharding.split_graph(
builder_exported_to_edge.edge_manager.exported_program(),
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
builder_exported_to_edge.metadata["get_n_layers"],
shares=args.num_sharding,
)
@@ -804,15 +792,19 @@ def _to_edge_and_lower_llama( # noqa: C901
atten.head_dim,
)
)
# pyre-ignore
tag_quant_io(
builder_exported_to_edge.edge_manager.exported_program().graph_module,
partial(get_custom_quant_ios_dtype, cache_shape),
partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore
)

logging.info("Lowering model using following partitioner(s): ")
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -826,6 +818,7 @@ def _to_edge_and_lower_llama( # noqa: C901
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch(
@@ -847,55 +840,11 @@ def _to_edge_and_lower_llama( # noqa: C901
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch(passes=additional_passes)

return builder


def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()
builder_exported.run_canonical_optimizations()
modelname = builder_exported.modelname

if args.export_only:
exit()

if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
# Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
args.xnnpack = True

if args.xnnpack:
builder = _to_edge_and_lower_llama_xnnpack(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
)
else:
builder = _to_edge_and_lower_llama(
builder_exported,
modelname,
additional_passes,
pt2e_quant_params,
quantizers,
quant_dtype,
args,
)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -917,6 +866,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
output_file = f"{builder.output_dir}/{modelname}.pte"

builder.save_to_pte(output_file)

return builder
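
Note (not part of the diff): the hunks above collapse the split _to_edge_and_lower_llama_xnnpack / _to_edge_and_lower_llama helpers introduced in #8501 back into a single _export_llama flow. The sketch below is a simplified, hedged reconstruction of that restored flow, not the literal file contents; it assumes it sits inside export_llama_lib.py so the module's helpers (_validate_args, _prepare_for_llama_export, get_quantizer_and_quant_params, get_xnnpack_partitioner) are in scope, and it omits the Vulkan/MPS/CoreML/QNN branches.

# Simplified sketch of the restored _export_llama flow (XNNPACK path only).
def _export_llama_sketch(args) -> LLMEdgeManager:
    _validate_args(args)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    builder_exported = _prepare_for_llama_export(args).export()
    builder_exported.run_canonical_optimizations()

    partitioners = []
    # Order matters: the dynamic-quant-only XNNPACK partitioner is appended first so
    # that dynamically quantized linears are claimed before the general partitioner.
    if (
        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
    ) or args.xnnpack:
        partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))
        args.xnnpack = True
    if args.xnnpack_extended_ops:
        partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=False))

    # Lower to the edge dialect, delegate to the collected partitioners, then emit the
    # ExecuTorch program.
    builder = (
        builder_exported.pt2e_quantize(quantizers)
        .export_to_edge()
        .to_backend(partitioners)
        .to_executorch()
    )
    return builder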


9 changes: 5 additions & 4 deletions examples/models/llama/source_transformation/quantize.py
@@ -119,10 +119,11 @@ def quantize( # noqa C901
# Check for required args
if group_size is None:
raise Exception("For 8da4w quantization, group size must be specified.")
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
model = Int8DynActInt4WeightQuantizer(
precision=torch_dtype, groupsize=group_size
).quantize(model)

if verbose:
print("quantized model:", model)
@@ -662,7 +663,7 @@ def convert_for_runtime(self) -> nn.Module:
def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict, assign=True)
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
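
The first quantize.py hunk moves the 8da4w path off torchao's newer config-style quantize_ API and back onto the eager Int8DynActInt4WeightQuantizer; the second hunk drops assign=True from load_state_dict, restoring the default copy-based load. Below is a minimal sketch contrasting the two 8da4w call patterns exactly as they appear in the hunk (it assumes a torchao build that ships both APIs).

import torch
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer


def quantize_8da4w_config_api(model: torch.nn.Module, group_size: int) -> torch.nn.Module:
    # The call being reverted: config-style API, rewrites the model in place.
    quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
    return model


def quantize_8da4w_eager_api(
    model: torch.nn.Module, group_size: int, torch_dtype: torch.dtype = torch.float32
) -> torch.nn.Module:
    # The call being restored: eager quantizer object that returns the quantized module.
    return Int8DynActInt4WeightQuantizer(
        precision=torch_dtype, groupsize=group_size
    ).quantize(model)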


1 change: 1 addition & 0 deletions examples/models/llava/export_llava.py
@@ -67,6 +67,7 @@ def export(self) -> "LlavaEdgeManager":
dynamic_shapes=dynamic_shape,
strict=False,
)
# pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
self.pre_autograd_graph_module = self.export_program.module()
return self
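
The only change to export_llava.py is the new pyre suppression on the assignment out of self.export_program.module(). A tiny, hypothetical repro of the type mismatch it silences is sketched below; Holder and capture are illustrative names, not code from the repo.

from typing import Optional, Tuple

import torch
from torch.export import export
from torch.fx import GraphModule


class Holder:
    # Mirrors the LLMEdgeManager attribute declaration the diff comment refers to.
    pre_autograd_graph_module: Optional[GraphModule] = None


def capture(model: torch.nn.Module, example_inputs: Tuple[torch.Tensor, ...]) -> Holder:
    holder = Holder()
    exported = export(model, example_inputs, strict=False)
    # ExportedProgram.module() is annotated as returning torch.nn.Module, so assigning
    # it to an Optional[GraphModule] attribute is what pyre flags, even though the
    # object returned at runtime is a GraphModule.
    holder.pre_autograd_graph_module = exported.module()  # pyre-ignore[8]
    return holder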

28 changes: 6 additions & 22 deletions extension/llm/export/builder.py
@@ -21,7 +21,7 @@
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
@@ -39,7 +39,7 @@
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.export import export_for_training, ExportedProgram
from torch.export import export_for_training
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -89,8 +89,8 @@ def __init__(
dynamic_shapes: Optional[Any] = None,
):
self.model = model
self.pre_autograd_exported_program: Optional[ExportedProgram] = None
self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
# graph module returned from export()
self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
self.modelname = modelname
self.max_seq_len = max_seq_len
self.dtype = dtype
@@ -218,8 +218,8 @@ def export(self) -> "LLMEdgeManager":
kwargs=self.example_kwarg_inputs,
dynamic_shapes=dynamic_shape,
)
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
# `Module`.
self.pre_autograd_exported_program = exported_module
self.pre_autograd_graph_module = exported_module.module()
if hasattr(self.args, "export_only") and self.args.export_only:
torch.export.save(exported_module, self.args.output_name)
@@ -330,10 +330,7 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(
self.pre_autograd_graph_module, # pyre-ignore[6]
composed_quantizer,
)
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
@@ -433,19 +430,6 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":

return self

def to_edge_transform_and_lower(
self, partitioners: Optional[List[Partitioner]]
) -> "LLMEdgeManager":
if partitioners is None:
logging.info("No partitioner provided, skipping backend lowering...")
edge_config = self._get_edge_config()
self.edge_manager = to_edge_transform_and_lower(
self.pre_autograd_exported_program,
partitioner=partitioners,
compile_config=edge_config,
)
return self

def to_executorch(
self, passes: Optional[List[ExportPass]] = None
) -> "LLMEdgeManager":
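
Finally, the builder.py hunks remove the to_edge_transform_and_lower() entry point and the pre_autograd_exported_program attribute that #8501 added, leaving export_to_edge() followed by to_backend() as the lowering path. A hedged sketch of the two paths is below; method names are those visible in the diff, and manager is assumed to be an LLMEdgeManager that has already gone through export() and pt2e_quantize().

from typing import List, Optional

from executorch.exir.backend.partitioner import Partitioner


def lower(manager, partitioners: Optional[List[Partitioner]], one_shot: bool = False):
    if one_shot:
        # Path removed by this revert: edge conversion and backend delegation are
        # fused into a single to_edge_transform_and_lower() call.
        return manager.to_edge_transform_and_lower(partitioners)
    # Path that remains: convert to the edge dialect first, then delegate per backend.
    return manager.export_to_edge().to_backend(partitioners)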