4
4
import torch
5
5
from executorch .devtools .backend_debug import get_delegation_info
6
6
from executorch .exir ._warnings import experimental
7
+ from executorch .exir .backend .backend_api import validation_disabled
7
8
from executorch .exir .program import (
8
9
EdgeProgramManager ,
9
10
ExecutorchProgramManager ,
10
11
to_edge_transform_and_lower ,
11
12
)
12
13
from executorch .exir .schema import Program
14
+ from executorch .extension .export_util .utils import save_pte_program
13
15
from executorch .runtime import Runtime , Verification
14
16
from tabulate import tabulate
15
17
from torch import nn
16
18
from torch .ao .quantization import allow_exported_model_train_eval
19
+ from torch .ao .quantization .quantizer .composable_quantizer import ComposableQuantizer
17
20
from torch .export import ExportedProgram
18
21
from torchao .quantization import quantize_
19
22
from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
@@ -145,15 +148,15 @@ def run(
145
148
model ,
146
149
self ._example_inputs_dict [method_name ][0 ],
147
150
dynamic_shapes = dynamic_shapes ,
151
+ strict = True ,
148
152
)
149
153
150
154
# Apply pre-edge transform passes if available
151
155
if self ._pre_edge_transform_passes is not None :
152
- self . _exported_program [ method_name ] = (
153
- self ._pre_edge_transform_passes (
156
+ for pre_edge_transform_pass in self . _pre_edge_transform_passes :
157
+ self ._exported_program [ method_name ] = pre_edge_transform_pass (
154
158
self ._exported_program [method_name ]
155
159
)
156
- )
157
160
158
161
def get_artifacts (self ) -> Dict [str , ExportedProgram ]:
159
162
"""
@@ -210,13 +213,14 @@ def run(
210
213
self ._constant_methods = transform_config .get ("constant_methods" , None )
211
214
212
215
# Process inputs
213
- self ._edge_program_manager = to_edge_transform_and_lower (
214
- self ._exported_program ,
215
- partitioner = self ._partitioners ,
216
- transform_passes = self ._transform_passes ,
217
- constant_methods = self ._constant_methods ,
218
- compile_config = self ._compile_config ,
219
- )
216
+ with validation_disabled ():
217
+ self ._edge_program_manager = to_edge_transform_and_lower (
218
+ self ._exported_program ,
219
+ partitioner = self ._partitioners ,
220
+ transform_passes = self ._transform_passes ,
221
+ constant_methods = self ._constant_methods ,
222
+ compile_config = self ._compile_config ,
223
+ )
220
224
self ._delegation_info = get_delegation_info (
221
225
self ._edge_program_manager .exported_program ().graph_module
222
226
)
@@ -345,8 +349,8 @@ class QuantizeStage(Stage):
345
349
Optional stage: Perform post-training quantization on the model.
346
350
"""
347
351
348
- def __init__ (self , quantizer : Any ) -> None :
349
- self ._quantizer = quantizer
352
+ def __init__ (self , quantizers : Any ) -> None :
353
+ self ._quantizers = quantizers
350
354
self ._quantized_models : Dict [str , nn .Module ] = {}
351
355
self ._model_dict : Dict [str , nn .Module ] = {}
352
356
self ._exported_program_dict : Dict [str , ExportedProgram ] = {}
@@ -394,7 +398,8 @@ def run(
394
398
model = exported_program .module ()
395
399
396
400
# Prepare the model for quantization
397
- prepared_model = prepare_pt2e (model , self ._quantizer ) # type: ignore
401
+ composed_quantizer = ComposableQuantizer (self ._quantizers )
402
+ prepared_model = prepare_pt2e (model , composed_quantizer ) # type: ignore
398
403
399
404
# Allow the model to switch between train and eval modes
400
405
allow_exported_model_train_eval (prepared_model )
@@ -546,9 +551,9 @@ def __init__(
546
551
547
552
# Create the quantize stage if a quantizer is provided
548
553
if self ._export_recipe .quantization_recipe is not None :
549
- quantizer = self ._export_recipe .quantization_recipe .get_quantizer ()
550
- if quantizer is not None :
551
- quantize_stage = QuantizeStage (quantizer = quantizer )
554
+ quantizers = self ._export_recipe .quantization_recipe .get_quantizers ()
555
+ if quantizers is not None :
556
+ quantize_stage = QuantizeStage (quantizers = quantizers )
552
557
self ._pipeline .append (quantize_stage )
553
558
554
559
# Create the edge transform and lower stage
@@ -661,6 +666,22 @@ def get_executorch_program(self) -> Program:
661
666
)
662
667
return self ._executorch_program_manager .executorch_program
663
668
669
+ def get_executorch_program_manager (self ) -> ExecutorchProgramManager :
670
+ """
671
+ Get the ExecutorchProgramManager.
672
+
673
+ Returns:
674
+ The ExecutorchProgramManager
675
+
676
+ Raises:
677
+ RuntimeError: If the executorch program manager is not initialized
678
+ """
679
+ if self ._executorch_program_manager is None :
680
+ raise RuntimeError (
681
+ "Executorch program manager is not initialized. Run export() first."
682
+ )
683
+ return self ._executorch_program_manager
684
+
664
685
def get_pte_buffer (self ) -> bytes :
665
686
"""
666
687
Get the PTE buffer as bytes.
@@ -677,6 +698,20 @@ def get_pte_buffer(self) -> bytes:
677
698
)
678
699
return self ._executorch_program_manager .buffer
679
700
701
+ def save_to_pte (self , output_name : str ) -> None :
702
+ """
703
+ Save the model to a .pte file.
704
+
705
+ Args:
706
+ output_name (Optional[str]): The name of the .pte file.
707
+ """
708
+ assert output_name , "Need a valid output name"
709
+ if self ._executorch_program_manager is None :
710
+ raise RuntimeError (
711
+ "Executorch program manager is not initialized. Run export() first."
712
+ )
713
+ save_pte_program (self ._executorch_program_manager , output_name )
714
+
680
715
def get_example_input (
681
716
self , method_name : str = "forward"
682
717
) -> Tuple [torch .Tensor , ...]:
0 commit comments