
Commit 5e3db28

[ET-VK] Enable Partial GPU lowering via Vulkan in stories model export
Pull Request resolved: #2368

## Context

Simple change to add the Vulkan partitioner as a dependency for the llama exporter and runner, and to provide a command line flag that invokes the Vulkan partitioner during export.

ghstack-source-id: 218708315
@exported-using-ghexport

Differential Revision: [D54805831](https://our.internmc.facebook.com/intern/diff/D54805831/)
1 parent: e98a7e0
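In short: this adds a `-V` / `--vulkan` flag to the llama export script; when set, `VulkanPartitioner` is registered so that supported subgraphs are delegated to the Vulkan backend while the rest stay on the CPU path (partial lowering). An illustrative invocation might look like `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json --vulkan`, where everything other than the `--vulkan` flag itself (entry point, checkpoint, and params arguments) is an assumption about the surrounding export tooling rather than part of this diff.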

File tree: 3 files changed, +19 / -1 lines

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -82,6 +82,7 @@ runtime.python_library(
         "//executorch/backends/transforms:duplicate_dynamic_quant_chain",
         "//executorch/backends/xnnpack:xnnpack_backend",
         "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/examples/portable:utils",
```

examples/models/llama2/export_llama_lib.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 
 import pkg_resources
 import torch
+from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackDynamicallyQuantizedPartitioner,
 )
@@ -359,6 +360,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument("-X", "--xnnpack", action="store_true")
+    parser.add_argument("-V", "--vulkan", action="store_true")
 
     parser.add_argument(
         "--generate_etrecord",
@@ -463,6 +465,17 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         # partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner()
         modelname = f"xnnpack_{modelname}"
 
+    if args.vulkan:
+        assert (
+            args.dtype_override is None
+        ), "Vulkan backend does not support non fp32 dtypes at the moment"
+        assert (
+            args.quantization_mode is None
+        ), "Vulkan backend does not support quantization at the moment"
+
+        partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
+        modelname = f"vulkan_{modelname}"
+
     builder_exported_to_edge = (
         load_llama_model(
             checkpoint=checkpoint_path,
```
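The new `if args.vulkan:` branch only registers the partitioner; the actual delegation happens later when the export pipeline applies each entry in `partitioners`. As a minimal, self-contained sketch of that general ExecuTorch lowering pattern (the toy model, `to_edge` usage, and output file name are illustrative assumptions, not this commit's code):

```python
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge


class TinyAdd(torch.nn.Module):
    """Toy module standing in for the llama model (illustrative only)."""

    def forward(self, x, y):
        return x + y


# Export to an ATen-level graph, then convert to the Edge dialect.
exported = torch.export.export(TinyAdd(), (torch.randn(4), torch.randn(4)))
edge = to_edge(exported)

# Partial lowering: subgraphs the Vulkan partitioner supports are delegated
# to the Vulkan backend; everything else stays on the portable CPU path.
edge = edge.to_backend(VulkanPartitioner())

# Serialize the final ExecuTorch program.
with open("vulkan_tiny_add.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```

The asserts reflect the backend's limits at this commit, fp32 only and no quantization, hence the guards on `args.dtype_override` and `args.quantization_mode`.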

examples/models/llama2/runner/targets.bzl

Lines changed: 5 additions & 1 deletion
```diff
@@ -36,7 +36,11 @@ def define_common_targets():
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/kernels/quantized:generated_lib" + aten_suffix,
             "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-        ] + (_get_operator_lib(aten)),
+        ] + (_get_operator_lib(aten)) + ([
+            # Vulkan API currently cannot build on some platforms (e.g. Apple, FBCODE)
+            # Therefore enable it explicitly for now to avoid failing tests
+            "//executorch/backends/vulkan:vulkan_backend_lib",
+        ] if native.read_config("llama", "use_vulkan", "0") == "1" else []),
         external_deps = [
             "libtorch",
         ] if aten else [],
```
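Because the Vulkan backend library does not yet build on every platform, the runner links it behind a Buck config flag rather than unconditionally. `native.read_config("llama", "use_vulkan", "0")` reads the `use_vulkan` key from the `llama` config section, so a build opts in with Buck's standard `-c section.key=value` override, e.g. `-c llama.use_vulkan=1` on the build command line or the equivalent `.buckconfig` entry (the exact build target is not shown in this diff).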
