
Commit 385c5e1

Update on "Reduce memory requirement on export_llama tests with no params"
For some reason, after the previous PR in the stack, test_export_llama_lib was OOMing on GitHub Actions CI, and I couldn't figure out why: profiling the test's running memory before and after the PR showed the same usage. This change fixes the CI OOM, and I've been meaning to make it anyway: if we are loading a transformer without params specified, we likely just want to test some basic functionality, so a 1-layer default makes more sense than an 8-layer one. I've confirmed that no other code currently relies on the 8-layer default. Differential Revision: [D75498713](https://our.internmc.facebook.com/intern/diff/D75498713) [ghstack-poisoned]
2 parents 548ef87 + 045dbf1
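As a hedged illustration of the change the message describes (the dataclass below is a simplified stand-in; the real default lives in the llama example's model args, and the surrounding fields are assumptions):

from dataclasses import dataclass

# Simplified sketch, not the actual executorch source: when no params.json
# is supplied, fall back to a tiny 1-layer transformer instead of the old
# 8-layer default, since such a model only smoke-tests export functionality.
@dataclass
class ModelArgs:
    dim: int = 4096
    n_heads: int = 32
    n_layers: int = 1  # previously defaulted to 8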

File tree

4 files changed: +79 -1 lines changed

examples/models/llama/TARGETS

Lines changed: 38 additions & 0 deletions
@@ -82,6 +82,8 @@ runtime.python_binary(
     ],
     deps = [
         ":export_library",
+        ":export_llama_args",
+        ":export_llama_hydra",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",
     ],
@@ -148,6 +150,8 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/examples/models/llama/config:llm_config",
+        "//executorch/examples/models/llama/config:llm_config_utils",
         "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
@@ -231,6 +235,40 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "export_llama_args",
+    srcs = [
+        "export_llama_args.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":export_library",
+    ],
+)
+
+runtime.python_library(
+    name = "export_llama_hydra",
+    srcs = [
+        "export_llama_hydra.py",
+    ],
+    _is_external_target = True,
+    base_module = "executorch.examples.models.llama",
+    visibility = [
+        "//executorch/examples/...",
+        "@EXECUTORCH_CLIENTS",
+    ],
+    deps = [
+        ":export_library",
+        "//executorch/examples/models/llama/config:llm_config",
+        "fbsource//third-party/pypi/hydra-core:hydra-core",
+    ],
+)
+
 runtime.python_test(
     name = "quantized_kv_cache_test",
     srcs = [
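The new export_llama_hydra target pulls in hydra-core plus the llm_config library, pointing at a hydra-driven entry point alongside the plain-args one. A minimal sketch of what such an entry point typically looks like, assuming LlmConfig is the structured config (the function body and registration details are illustrative, not the file's actual contents):

import hydra
from hydra.core.config_store import ConfigStore

from executorch.examples.models.llama.config.llm_config import LlmConfig

# Register LlmConfig as a structured config so hydra validates CLI
# overrides (e.g. base.checkpoint=...) against the dataclass schema.
cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)

@hydra.main(version_base=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
    # Illustrative placeholder: the real target would hand the parsed
    # config to the shared export library.
    print(llm_config)

if __name__ == "__main__":
    main()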

examples/models/llama/config/TARGETS

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()

examples/models/llama/config/llm_config.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ class CoreMLConfig:
     enable_state: bool = False
     preserve_sdpa: bool = False
     quantize: Optional[CoreMLQuantize] = None
-    ios: Literal[15, 16, 17, 18] = 15
+    ios: int = 15
     compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
 
     def __post_init__(self):
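Relaxing ios from Literal[15, 16, 17, 18] to a plain int is plausibly motivated by omegaconf/hydra structured configs, which cannot represent Literal annotations; the trailing context shows the class already has a __post_init__, the natural home for the same constraint as a runtime check. A hedged sketch of that pattern (the exact check in the real file is an assumption):

from dataclasses import dataclass

@dataclass
class CoreMLConfig:
    ios: int = 15

    def __post_init__(self):
        # With the Literal annotation gone, enforce the supported iOS
        # deployment targets at runtime instead of in the type system.
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid iOS version: {self.ios}")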
examples/models/llama/config/targets.bzl

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    runtime.python_library(
+        name = "llm_config",
+        srcs = [
+            "llm_config.py",
+        ],
+        _is_external_target = True,
+        base_module = "executorch.examples.models.llama.config",
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
+    runtime.python_library(
+        name = "llm_config_utils",
+        srcs = [
+            "llm_config_utils.py",
+        ],
+        _is_external_target = True,
+        base_module = "executorch.examples.models.llama.config",
+        visibility = [
+            "//executorch/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        deps = [
+            ":llm_config",
+        ],
+    )
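llm_config_utils depends on llm_config, suggesting it hosts helpers that build an LlmConfig from other sources, such as the legacy argparse namespace consumed by export_llama_args. A speculative sketch of that shape (the helper name and field mapping are hypothetical):

import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig

def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
    # Hypothetical helper: fold legacy CLI flags into the dataclass so
    # both entry points converge on one LlmConfig downstream.
    llm_config = LlmConfig()
    if getattr(args, "checkpoint", None) is not None:
        # Field path assumed for illustration.
        llm_config.base.checkpoint = args.checkpoint
    return llm_config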
