|
22 | 22 | from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
|
23 | 23 | XnnpackDynamicallyQuantizedPartitioner,
|
24 | 24 | )
|
| 25 | +from executorch.exir.backend.backend_details import CompileSpec |
25 | 26 |
|
26 | 27 | from executorch.sdk.etrecord import generate_etrecord
|
27 | 28 | from executorch.util.activation_memory_profiler import generate_memory_trace
|
@@ -366,6 +367,7 @@ def build_args_parser() -> argparse.ArgumentParser:
|
366 | 367 | parser.add_argument("-v", "--verbose", action="store_true")
|
367 | 368 | parser.add_argument("-X", "--xnnpack", action="store_true")
|
368 | 369 | parser.add_argument("-V", "--vulkan", action="store_true")
|
| 370 | + parser.add_argument("--mps", action="store_true") |
369 | 371 |
|
370 | 372 | parser.add_argument(
|
371 | 373 | "--generate_etrecord",
|
@@ -517,6 +519,25 @@ def _export_llama(modelname, args) -> str: # noqa: C901
|
517 | 519 | partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
|
518 | 520 | modelname = f"vulkan_{modelname}"
|
519 | 521 |
|
| 522 | + if args.mps: |
| 523 | + assert ( |
| 524 | + args.use_kv_cache is True |
| 525 | + ), "MPS backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" |
| 526 | + try: |
| 527 | + # pyre-ignore Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.mps.partition.mps_partitioner`. |
| 528 | + from executorch.backends.apple.mps.partition.mps_partitioner import ( |
| 529 | + MPSPartitioner, |
| 530 | + ) |
| 531 | + except ImportError: |
| 532 | + raise ImportError( |
| 533 | + "Please install the MPS backend follwing https://pytorch.org/executorch/main/build-run-mps.html" |
| 534 | + ) |
| 535 | + |
| 536 | + compile_specs = [CompileSpec("use_fp16", bytes([True]))] |
| 537 | + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`. |
| 538 | + partitioners[MPSPartitioner.__name__] = MPSPartitioner(compile_specs) |
| 539 | + modelname = f"mps_{modelname}" |
| 540 | + |
520 | 541 | if args.generate_etrecord:
|
521 | 542 | if not builder_exported_to_edge.edge_manager:
|
522 | 543 | raise ValueError("Unable to generate etrecord due to missing edge manager.")
|
|
0 commit comments