Skip to content

Commit 1a85097

Browse files
committed
Update on "Use llm_config instead of args in export_llama functions"
Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927) [ghstack-poisoned]
2 parents 4760311 + a287f35 commit 1a85097

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ runtime.python_library(
152152
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
153153
"//caffe2:torch",
154154
"//executorch/examples/models/llama/config:llm_config",
155+
"//executorch/examples/models/llama/config:llm_config_utils",
155156
"//executorch/backends/vulkan/_passes:vulkan_passes",
156157
"//executorch/exir/passes:init_mutable_pass",
157158
"//executorch/examples/models:model_base",

examples/models/llama/config/llm_config_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,17 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
121121
if hasattr(args, "coreml"):
122122
llm_config.backend.coreml.enabled = args.coreml
123123
llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
124-
llm_config.backend.coreml.preserve_sdpa = getattr(args, "coreml_preserve_sdpa", False)
124+
llm_config.backend.coreml.preserve_sdpa = getattr(
125+
args, "coreml_preserve_sdpa", False
126+
)
125127
if hasattr(args, "coreml_quantize") and args.coreml_quantize:
126128
llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
127129
if hasattr(args, "coreml_ios"):
128130
llm_config.backend.coreml.ios = args.coreml_ios
129131
if hasattr(args, "coreml_compute_units"):
130-
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(args.coreml_compute_units)
132+
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
133+
args.coreml_compute_units
134+
)
131135

132136
# Vulkan
133137
if hasattr(args, "vulkan"):

0 commit comments

Comments
 (0)