Skip to content

Commit b308544

Browse files
authored
Export recipe changes for export_llama integration
Differential Revision: D75628345 Pull Request resolved: #11227
1 parent cd49b58 commit b308544

File tree

3 files changed

+57
-20
lines changed

3 files changed

+57
-20
lines changed

export/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ python_library(
1212
"//executorch/exir/backend:backend_api",
1313
"//executorch/exir:pass_manager",
1414
"//executorch/devtools/backend_debug:delegation_info",
15+
"//executorch/extension/export_util:export_util",
1516
]
1617
)
1718

export/export.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
import torch
55
from executorch.devtools.backend_debug import get_delegation_info
66
from executorch.exir._warnings import experimental
7+
from executorch.exir.backend.backend_api import validation_disabled
78
from executorch.exir.program import (
89
EdgeProgramManager,
910
ExecutorchProgramManager,
1011
to_edge_transform_and_lower,
1112
)
1213
from executorch.exir.schema import Program
14+
from executorch.extension.export_util.utils import save_pte_program
1315
from executorch.runtime import Runtime, Verification
1416
from tabulate import tabulate
1517
from torch import nn
1618
from torch.ao.quantization import allow_exported_model_train_eval
19+
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
1720
from torch.export import ExportedProgram
1821
from torchao.quantization import quantize_
1922
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -145,15 +148,15 @@ def run(
145148
model,
146149
self._example_inputs_dict[method_name][0],
147150
dynamic_shapes=dynamic_shapes,
151+
strict=True,
148152
)
149153

150154
# Apply pre-edge transform passes if available
151155
if self._pre_edge_transform_passes is not None:
152-
self._exported_program[method_name] = (
153-
self._pre_edge_transform_passes(
156+
for pre_edge_transform_pass in self._pre_edge_transform_passes:
157+
self._exported_program[method_name] = pre_edge_transform_pass(
154158
self._exported_program[method_name]
155159
)
156-
)
157160

158161
def get_artifacts(self) -> Dict[str, ExportedProgram]:
159162
"""
@@ -210,13 +213,14 @@ def run(
210213
self._constant_methods = transform_config.get("constant_methods", None)
211214

212215
# Process inputs
213-
self._edge_program_manager = to_edge_transform_and_lower(
214-
self._exported_program,
215-
partitioner=self._partitioners,
216-
transform_passes=self._transform_passes,
217-
constant_methods=self._constant_methods,
218-
compile_config=self._compile_config,
219-
)
216+
with validation_disabled():
217+
self._edge_program_manager = to_edge_transform_and_lower(
218+
self._exported_program,
219+
partitioner=self._partitioners,
220+
transform_passes=self._transform_passes,
221+
constant_methods=self._constant_methods,
222+
compile_config=self._compile_config,
223+
)
220224
self._delegation_info = get_delegation_info(
221225
self._edge_program_manager.exported_program().graph_module
222226
)
@@ -345,8 +349,8 @@ class QuantizeStage(Stage):
345349
Optional stage: Perform post-training quantization on the model.
346350
"""
347351

348-
def __init__(self, quantizer: Any) -> None:
349-
self._quantizer = quantizer
352+
def __init__(self, quantizers: Any) -> None:
353+
self._quantizers = quantizers
350354
self._quantized_models: Dict[str, nn.Module] = {}
351355
self._model_dict: Dict[str, nn.Module] = {}
352356
self._exported_program_dict: Dict[str, ExportedProgram] = {}
@@ -394,7 +398,8 @@ def run(
394398
model = exported_program.module()
395399

396400
# Prepare the model for quantization
397-
prepared_model = prepare_pt2e(model, self._quantizer) # type: ignore
401+
composed_quantizer = ComposableQuantizer(self._quantizers)
402+
prepared_model = prepare_pt2e(model, composed_quantizer) # type: ignore
398403

399404
# Allow the model to switch between train and eval modes
400405
allow_exported_model_train_eval(prepared_model)
@@ -546,9 +551,9 @@ def __init__(
546551

547552
# Create the quantize stage if a quantizer is provided
548553
if self._export_recipe.quantization_recipe is not None:
549-
quantizer = self._export_recipe.quantization_recipe.get_quantizer()
550-
if quantizer is not None:
551-
quantize_stage = QuantizeStage(quantizer=quantizer)
554+
quantizers = self._export_recipe.quantization_recipe.get_quantizers()
555+
if quantizers is not None:
556+
quantize_stage = QuantizeStage(quantizers=quantizers)
552557
self._pipeline.append(quantize_stage)
553558

554559
# Create the edge transform and lower stage
@@ -661,6 +666,22 @@ def get_executorch_program(self) -> Program:
661666
)
662667
return self._executorch_program_manager.executorch_program
663668

669+
def get_executorch_program_manager(self) -> ExecutorchProgramManager:
670+
"""
671+
Get the ExecutorchProgramManager.
672+
673+
Returns:
674+
The ExecutorchProgramManager
675+
676+
Raises:
677+
RuntimeError: If the executorch program manager is not initialized
678+
"""
679+
if self._executorch_program_manager is None:
680+
raise RuntimeError(
681+
"Executorch program manager is not initialized. Run export() first."
682+
)
683+
return self._executorch_program_manager
684+
664685
def get_pte_buffer(self) -> bytes:
665686
"""
666687
Get the PTE buffer as bytes.
@@ -677,6 +698,20 @@ def get_pte_buffer(self) -> bytes:
677698
)
678699
return self._executorch_program_manager.buffer
679700

701+
def save_to_pte(self, output_name: str) -> None:
702+
"""
703+
Save the model to a .pte file.
704+
705+
Args:
706+
output_name (Optional[str]): The name of the .pte file.
707+
"""
708+
assert output_name, "Need a valid output name"
709+
if self._executorch_program_manager is None:
710+
raise RuntimeError(
711+
"Executorch program manager is not initialized. Run export() first."
712+
)
713+
save_pte_program(self._executorch_program_manager, output_name)
714+
680715
def get_example_input(
681716
self, method_name: str = "forward"
682717
) -> Tuple[torch.Tensor, ...]:

export/recipe.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,17 @@ class QuantizationRecipe:
4949
quantizer: Optional quantizer for model quantization
5050
"""
5151

52-
quantizer: Optional[Quantizer] = None
52+
quantizers: Optional[List[Quantizer]] = None
5353
ao_base_config: Optional[List[AOBaseConfig]] = None
5454

55-
def get_quantizer(self) -> Optional[Quantizer]:
55+
def get_quantizers(self) -> Optional[Quantizer]:
5656
"""
5757
Get the quantizer associated with this recipe.
5858
5959
Returns:
6060
The quantizer if one is set, otherwise None
6161
"""
62-
return self.quantizer
62+
return self.quantizers
6363

6464

6565
@experimental(
@@ -94,10 +94,11 @@ class ExportRecipe:
9494
)
9595
pre_edge_transform_passes: Optional[
9696
Callable[[ExportedProgram], ExportedProgram]
97+
| List[Callable[[ExportedProgram], ExportedProgram]]
9798
] = None
9899
edge_transform_passes: Optional[Sequence[PassType]] = None
99100
transform_check_ir_validity: bool = True
100-
partitioners: Optional[list[Partitioner]] = None
101+
partitioners: Optional[List[Partitioner]] = None
101102
executorch_backend_config: Optional[ExecutorchBackendConfig] = (
102103
None # pyre-ignore[11]: Type not defined
103104
)

0 commit comments

Comments
 (0)