     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager
+from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -216,6 +216,7 @@ def export(self) -> "LLMEdgeManager":
             )
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             # `Module`.
+            self.pre_autograd_exported_program = exported_module
             self.pre_autograd_graph_module = exported_module.module()
             if hasattr(self.args, "export_only") and self.args.export_only:
                 torch.export.save(exported_module, self.args.output_name)
@@ -305,51 +306,51 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
         ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
         logging.info(f"Using pt2e {quantizers} to quantizing the model...")
 
+        if not quantizers:
+            logging.info("No quantizer provided, passing...")
+            return self
+
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        if quantizers:
-            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-                if self.verbose:
-                    logging.info(f"Applied quantizers: {quantizers}")
-                composed_quantizer = ComposableQuantizer(quantizers)
-                assert (
-                    self.pre_autograd_graph_module is not None
-                ), "Please run export() first"
-                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            if self.verbose:
+                logging.info(f"Applied quantizers: {quantizers}")
+            composed_quantizer = ComposableQuantizer(quantizers)
+            assert (
+                self.pre_autograd_graph_module is not None
+            ), "Please run export() first"
+            m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
+            logging.info(
+                f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
+            )
+            # Calibrate
+            if (
+                self.calibration_tasks is not None
+                and self.calibration_limit is not None
+                and self.calibration_seq_length is not None
+                and self.calibration_data is not None
+                and self.tokenizer_path is not None
+            ):
                 logging.info(
                     f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                 )
-                # Calibrate
-                if (
-                    self.calibration_tasks is not None
-                    and self.calibration_limit is not None
-                    and self.calibration_seq_length is not None
-                    and self.calibration_data is not None
-                    and self.tokenizer_path is not None
-                ):
-                    logging.info(
-                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
-                    )
-                    self.pt2e_calibrate(
-                        prepared_module=m,
-                        calibration_tasks=self.calibration_tasks,
-                        calibration_limit=self.calibration_limit,
-                        calibration_seq_length=self.calibration_seq_length,
-                        calibration_data=self.calibration_data,
-                        tokenizer_path=self.tokenizer_path,
-                    )
-                else:
-                    logging.info(
-                        "No calibration provided, using dummy input to calibrate..."
-                    )
-                    m(*self.example_inputs)
-                m = convert_pt2e(m)
-                DuplicateDynamicQuantChainPass()(m)
-                self.pre_autograd_graph_module = m
-            return self
-        else:
-            logging.info("No quantizer provided, passing...")
-            return self
+                self.pt2e_calibrate(
+                    prepared_module=m,
+                    calibration_tasks=self.calibration_tasks,
+                    calibration_limit=self.calibration_limit,
+                    calibration_seq_length=self.calibration_seq_length,
+                    calibration_data=self.calibration_data,
+                    tokenizer_path=self.tokenizer_path,
+                )
+            else:
+                logging.info(
+                    "No calibration provided, using dummy input to calibrate..."
+                )
+                m(*self.example_inputs, **self.example_kwarg_inputs)
+            m = convert_pt2e(m)
+            DuplicateDynamicQuantChainPass()(m)
+            self.pre_autograd_graph_module = m
+            return self
 
     def export_to_edge(self) -> "LLMEdgeManager":
         """
@@ -415,6 +416,17 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
+    def to_edge_transform_and_lower(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
+        if partitioners is None:
+            logging.info("No partitioner provided, skipping backend lowering...")
+        edge_config = self._get_edge_config()
+        self.edge_manager = to_edge_transform_and_lower(
+            self.pre_autograd_exported_program,
+            partitioner=partitioners,
+            compile_config=edge_config,
+        )
+        return self
+
     def to_executorch(self) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
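To illustrate the `to_edge_transform_and_lower` entry point that the new `LLMEdgeManager` method wraps, here is a minimal standalone sketch on a toy module. It is not part of this commit: the `TinyLinear` module is invented for illustration, the XNNPACK partitioner import path is assumed from a typical ExecuTorch checkout, and plain `torch.export.export` stands in for the pre-autograd export that `export()` performs in the diff above.

```python
# Minimal sketch (not from this commit) of the exir API that the new
# LLMEdgeManager.to_edge_transform_and_lower() method delegates to.
# Assumptions: XnnpackPartitioner import path matches a typical ExecuTorch
# install; torch.export.export stands in for the manager's pre-autograd export.
import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Export the eager module to an ExportedProgram (the role played by the new
# self.pre_autograd_exported_program attribute stored in export() above).
exported_program = torch.export.export(TinyLinear(), (torch.randn(1, 8),))

# Convert to edge dialect and delegate supported subgraphs in one call.
# partitioner=None is also accepted, in which case nothing is delegated,
# mirroring the "skipping backend lowering" branch of the new method.
edge_manager = to_edge_transform_and_lower(
    exported_program,
    partitioner=[XnnpackPartitioner()],
)
executorch_program = edge_manager.to_executorch()
```

In the manager itself, the same step would presumably be reached by chaining `export()`, optionally `pt2e_quantize(...)`, then `to_edge_transform_and_lower(partitioners)` and `to_executorch()`, replacing the separate `export_to_edge()` plus `to_backend()` path for callers that opt in.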