Export recipe changes for export_llama integration #11227

Merged: 1 commit, merged May 30, 2025
1 change: 1 addition & 0 deletions export/TARGETS
@@ -12,6 +12,7 @@ python_library(
         "//executorch/exir/backend:backend_api",
         "//executorch/exir:pass_manager",
         "//executorch/devtools/backend_debug:delegation_info",
+        "//executorch/extension/export_util:export_util",
     ]
 )
67 changes: 51 additions & 16 deletions export/export.py
@@ -4,16 +4,19 @@
 import torch
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_api import validation_disabled
 from executorch.exir.program import (
     EdgeProgramManager,
     ExecutorchProgramManager,
     to_edge_transform_and_lower,
 )
 from executorch.exir.schema import Program
+from executorch.extension.export_util.utils import save_pte_program
 from executorch.runtime import Runtime, Verification
 from tabulate import tabulate
 from torch import nn
 from torch.ao.quantization import allow_exported_model_train_eval
+from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.export import ExportedProgram
 from torchao.quantization import quantize_
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -145,15 +148,15 @@ def run(
                 model,
                 self._example_inputs_dict[method_name][0],
                 dynamic_shapes=dynamic_shapes,
+                strict=True,
             )
 
             # Apply pre-edge transform passes if available
             if self._pre_edge_transform_passes is not None:
-                self._exported_program[method_name] = (
-                    self._pre_edge_transform_passes(
+                for pre_edge_transform_pass in self._pre_edge_transform_passes:
+                    self._exported_program[method_name] = pre_edge_transform_pass(
                         self._exported_program[method_name]
                     )
-                )
 
     def get_artifacts(self) -> Dict[str, ExportedProgram]:
         """
@@ -210,13 +213,14 @@ def run(
         self._constant_methods = transform_config.get("constant_methods", None)
 
         # Process inputs
-        self._edge_program_manager = to_edge_transform_and_lower(
-            self._exported_program,
-            partitioner=self._partitioners,
-            transform_passes=self._transform_passes,
-            constant_methods=self._constant_methods,
-            compile_config=self._compile_config,
-        )
+        with validation_disabled():
+            self._edge_program_manager = to_edge_transform_and_lower(
+                self._exported_program,
+                partitioner=self._partitioners,
+                transform_passes=self._transform_passes,
+                constant_methods=self._constant_methods,
+                compile_config=self._compile_config,
+            )
         self._delegation_info = get_delegation_info(
             self._edge_program_manager.exported_program().graph_module
         )
@@ -345,8 +349,8 @@ class QuantizeStage(Stage):
     Optional stage: Perform post-training quantization on the model.
     """
 
-    def __init__(self, quantizer: Any) -> None:
-        self._quantizer = quantizer
+    def __init__(self, quantizers: Any) -> None:
+        self._quantizers = quantizers
         self._quantized_models: Dict[str, nn.Module] = {}
         self._model_dict: Dict[str, nn.Module] = {}
         self._exported_program_dict: Dict[str, ExportedProgram] = {}
@@ -394,7 +398,8 @@ def run(
             model = exported_program.module()
 
             # Prepare the model for quantization
-            prepared_model = prepare_pt2e(model, self._quantizer)  # type: ignore
+            composed_quantizer = ComposableQuantizer(self._quantizers)
+            prepared_model = prepare_pt2e(model, composed_quantizer)  # type: ignore
 
             # Allow the model to switch between train and eval modes
             allow_exported_model_train_eval(prepared_model)
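Why ComposableQuantizer: prepare_pt2e accepts a single quantizer, so the recipe's list is folded into one composed object. A minimal sketch, assuming the torch.ao XNNPACKQuantizer import path (these quantizer modules have moved between releases):

    from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # two independently configured quantizers, composed into the single
    # object that prepare_pt2e expects
    per_tensor = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    per_channel = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True)
    )
    composed = ComposableQuantizer([per_tensor, per_channel])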
@@ -546,9 +551,9 @@ def __init__(

         # Create the quantize stage if a quantizer is provided
         if self._export_recipe.quantization_recipe is not None:
-            quantizer = self._export_recipe.quantization_recipe.get_quantizer()
-            if quantizer is not None:
-                quantize_stage = QuantizeStage(quantizer=quantizer)
+            quantizers = self._export_recipe.quantization_recipe.get_quantizers()
+            if quantizers is not None:
+                quantize_stage = QuantizeStage(quantizers=quantizers)
                 self._pipeline.append(quantize_stage)
 
         # Create the edge transform and lower stage
@@ -661,6 +666,22 @@ def get_executorch_program(self) -> Program:
         )
         return self._executorch_program_manager.executorch_program
 
+    def get_executorch_program_manager(self) -> ExecutorchProgramManager:
+        """
+        Get the ExecutorchProgramManager.
+
+        Returns:
+            The ExecutorchProgramManager
+
+        Raises:
+            RuntimeError: If the executorch program manager is not initialized
+        """
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        return self._executorch_program_manager
+
     def get_pte_buffer(self) -> bytes:
         """
         Get the PTE buffer as bytes.
@@ -677,6 +698,20 @@ def get_pte_buffer(self) -> bytes:
         )
         return self._executorch_program_manager.buffer
 
+    def save_to_pte(self, output_name: str) -> None:
+        """
+        Save the model to a .pte file.
+
+        Args:
+            output_name (str): The name of the .pte file.
+        """
+        assert output_name, "Need a valid output name"
+        if self._executorch_program_manager is None:
+            raise RuntimeError(
+                "Executorch program manager is not initialized. Run export() first."
+            )
+        save_pte_program(self._executorch_program_manager, output_name)
+
     def get_example_input(
         self, method_name: str = "forward"
     ) -> Tuple[torch.Tensor, ...]:
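A short usage sketch for the two new accessors; `session` stands in for an instance of the enclosing class (its name sits outside the hunks shown), and the behavior in the comments is taken from the method bodies above:

    session.export()  # must run first; populates the internal program manager
    manager = session.get_executorch_program_manager()  # ExecutorchProgramManager
    session.save_to_pte("my_model")  # save_pte_program writes the .pte artifact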
9 changes: 5 additions & 4 deletions export/recipe.py
@@ -49,17 +49,17 @@ class QuantizationRecipe:
         quantizer: Optional quantizer for model quantization
     """
 
-    quantizer: Optional[Quantizer] = None
+    quantizers: Optional[List[Quantizer]] = None
     ao_base_config: Optional[List[AOBaseConfig]] = None
 
-    def get_quantizer(self) -> Optional[Quantizer]:
+    def get_quantizers(self) -> Optional[List[Quantizer]]:
         """
         Get the quantizer associated with this recipe.
 
         Returns:
             The quantizer if one is set, otherwise None
         """
-        return self.quantizer
+        return self.quantizers
 
 
 @experimental(
@@ -94,10 +94,11 @@ class ExportRecipe:
     )
     pre_edge_transform_passes: Optional[
         Callable[[ExportedProgram], ExportedProgram]
+        | List[Callable[[ExportedProgram], ExportedProgram]]
     ] = None
     edge_transform_passes: Optional[Sequence[PassType]] = None
     transform_check_ir_validity: bool = True
-    partitioners: Optional[list[Partitioner]] = None
+    partitioners: Optional[List[Partitioner]] = None
     executorch_backend_config: Optional[ExecutorchBackendConfig] = (
         None  # pyre-ignore[11]: Type not defined
     )
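A sketch tying the renamed recipe surface together. It reuses the quantizers from the earlier sketch, and assumes the XnnpackPartitioner import path and that both recipes are constructible as dataclasses:

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    # reusing per_tensor / per_channel from the quantizer sketch above
    quant_recipe = QuantizationRecipe(quantizers=[per_tensor, per_channel])  # was: quantizer=...
    export_recipe = ExportRecipe(
        quantization_recipe=quant_recipe,
        partitioners=[XnnpackPartitioner()],  # now Optional[List[Partitioner]]
    )
    assert quant_recipe.get_quantizers() == [per_tensor, per_channel]  # was: get_quantizer()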