
Commit c6b3437

[ET-VK] Enable Partial GPU lowering via Vulkan in stories model export
Pull Request resolved: #2368

## Context

Simple change to add the Vulkan partitioner as a dependency for the llama exporter and runner, and to provide a command line flag that invokes the Vulkan partitioner during export.

ghstack-source-id: 218421177
@exported-using-ghexport

Differential Revision: [D54805831](https://our.internmc.facebook.com/intern/diff/D54805831/)
1 parent: d65bfb3
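With this change, partial GPU lowering can be requested at export time through the new `-V`/`--vulkan` flag. As an illustrative invocation (the entry point and checkpoint arguments are assumptions, not taken from this commit): `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -V`.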

3 files changed: +11 -1 lines


examples/models/llama2/TARGETS

1 addition, 0 deletions

```diff
@@ -82,6 +82,7 @@ runtime.python_library(
         "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
         "//executorch/backends/xnnpack:xnnpack_backend",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/examples/portable:utils",
```

examples/models/llama2/export_llama_lib.py

5 additions, 0 deletions

```diff
@@ -17,6 +17,7 @@
 
 import pkg_resources
 import torch
+from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
@@ -356,6 +357,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument("-X", "--xnnpack", action="store_true")
+    parser.add_argument("-V", "--vulkan", action="store_true")
 
     return parser
 
@@ -451,6 +453,9 @@ def _export_llama(modelname, args) -> str: # noqa: C901
         )
         # partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner()
         modelname = f"xnnpack_{modelname}"
+    if args.vulkan:
+        partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
+        modelname = f"vulkan_{modelname}"
 
     builder = (
         load_llama_model(
```
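The diff above registers `VulkanPartitioner` in the exporter's `partitioners` dict. For context, here is a minimal, self-contained sketch of how a partitioner like this is typically applied in an ExecuTorch export flow; the toy module, output path, and the exact plumbing inside `export_llama_lib.py` are assumptions for illustration, not code from this commit:

```python
# Minimal sketch under stated assumptions: delegate what the Vulkan
# partitioner claims, keep the rest on the portable CPU path.
import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge


class TinyModule(torch.nn.Module):  # toy stand-in for the llama model
    def forward(self, x):
        return torch.relu(x + 1)


model = TinyModule().eval()
example_inputs = (torch.randn(1, 8),)

# Export to the Edge dialect, then hand supported subgraphs to the Vulkan
# backend. Ops the partitioner does not claim stay in the default runtime,
# which is why the commit title calls this *partial* GPU lowering.
edge = to_edge(torch.export.export(model, example_inputs))
edge = edge.to_backend(VulkanPartitioner())
exec_prog = edge.to_executorch()

with open("tiny_vulkan.pte", "wb") as f:  # hypothetical output file name
    f.write(exec_prog.buffer)
```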

examples/models/llama2/runner/targets.bzl

5 additions, 1 deletion

```diff
@@ -36,7 +36,11 @@ def define_common_targets():
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/kernels/quantized:generated_lib" + aten_suffix,
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-        ] + (_get_operator_lib(aten)),
+        ] + (_get_operator_lib(aten)) + ([
+            # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
+            # Therefore enable it explicitly for now to avoid failing tests
+            "//executorch/backends/vulkan:vulkan_backend_lib",
+        ] if native.read_config("llama", "use_vulkan", "0") == "1" else []),
         external_deps = [
             "libtorch",
         ] if aten else [],
```
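The extra runner dependency is gated behind a build config rather than added unconditionally, because the Vulkan backend does not yet build on all platforms. Assuming standard Buck handling of `native.read_config`, a Vulkan-enabled runner build would be selected with something like `-c llama.use_vulkan=1` on the build command line; the exact invocation is an assumption, not part of this commit.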
