Skip to content

Commit 18664d7

Browse files
committed
Update base for Update on "[ET-VK] Add binary op support for height and width packing GPU layouts"
## Context Enable `binary_op` to support inputs that are `HEIGHT_PACKED` and `WIDTH_PACKED`. Differential Revision: [D55031044](https://our.internmc.facebook.com/intern/diff/D55031044/) [ghstack-poisoned]
2 parents aaa48d2 + 9b5bd5e commit 18664d7

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
2323
XnnpackDynamicallyQuantizedPartitioner,
2424
)
25+
from executorch.exir.backend.backend_details import CompileSpec
2526

2627
from executorch.sdk.etrecord import generate_etrecord
2728
from executorch.util.activation_memory_profiler import generate_memory_trace
@@ -366,6 +367,7 @@ def build_args_parser() -> argparse.ArgumentParser:
366367
parser.add_argument("-v", "--verbose", action="store_true")
367368
parser.add_argument("-X", "--xnnpack", action="store_true")
368369
parser.add_argument("-V", "--vulkan", action="store_true")
370+
parser.add_argument("--mps", action="store_true")
369371

370372
parser.add_argument(
371373
"--generate_etrecord",
@@ -517,6 +519,25 @@ def _export_llama(modelname, args) -> str: # noqa: C901
517519
partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
518520
modelname = f"vulkan_{modelname}"
519521

522+
if args.mps:
523+
assert (
524+
args.use_kv_cache is True
525+
), "MPS backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
526+
try:
527+
# pyre-ignore Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.mps.partition.mps_partitioner`.
528+
from executorch.backends.apple.mps.partition.mps_partitioner import (
529+
MPSPartitioner,
530+
)
531+
except ImportError:
532+
raise ImportError(
533+
"Please install the MPS backend follwing https://pytorch.org/executorch/main/build-run-mps.html"
534+
)
535+
536+
compile_specs = [CompileSpec("use_fp16", bytes([True]))]
537+
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `apple`.
538+
partitioners[MPSPartitioner.__name__] = MPSPartitioner(compile_specs)
539+
modelname = f"mps_{modelname}"
540+
520541
if args.generate_etrecord:
521542
if not builder_exported_to_edge.edge_manager:
522543
raise ValueError("Unable to generate etrecord due to missing edge manager.")

0 commit comments

Comments
 (0)