
Commit 97ec69c

Update on "Use llm_config instead of args in export_llama functions"

Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927) [ghstack-poisoned]

Merge commit 97ec69c, 2 parents: 45571eb + 5859561

3 files changed: +5 −17 lines

examples/models/llama/model.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -8,7 +8,7 @@
 
 import json
 import os
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 from executorch.examples.models.checkpoint import (
@@ -37,17 +37,13 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, llm_config: LlmConfig):
+    def __init__(self, llm_config: Optional[LlmConfig] = None):
         resource_dir = get_default_model_resource_dir(__file__)
 
-        self.llm_config = llm_config
+        self.llm_config = llm_config if llm_config else LlmConfig()
 
-        # Use single checkpoint file.
         checkpoint_path = self.llm_config.base.checkpoint
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = self.llm_config.base.checkpoint_dir
-
-        # Params file.
         params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
```
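With `llm_config` now optional, `Llama2Model` can be constructed with an all-defaults `LlmConfig` or an explicit one. A minimal sketch of both call styles; the `Llama2Model` import path is assumed from the file's location, and the checkpoint path is a hypothetical placeholder:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.model import Llama2Model  # assumed import path

# All-defaults construction: __init__ now falls back to LlmConfig()
# when no config is passed.
model = Llama2Model()

# Explicit construction, setting the fields __init__ reads above.
llm_config = LlmConfig()
llm_config.base.checkpoint = "llama2.pth"  # hypothetical checkpoint path
llm_config.base.params = "params.json"
llm_config.model.use_kv_cache = True
model = Llama2Model(llm_config=llm_config)
```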

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -48,7 +48,7 @@ def test_has_expected_ops_and_op_counts(self):
         args.use_kv_cache = True
         args.verbose = True
 
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         graph_module = builder.edge_manager.exported_program().graph_module
         delegation_info = get_delegation_info(graph_module)
 
```
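`_export_llama` now takes the `LlmConfig` alone; the legacy `args` namespace is no longer threaded through. A hedged sketch of the new call shape, assuming `_export_llama` is importable from `export_llama_lib` as the test module's name suggests:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import _export_llama  # assumed import path

llm_config = LlmConfig()
llm_config.model.use_kv_cache = True  # mirrors args.use_kv_cache in the old test body

# Single-argument form introduced by this commit.
builder = _export_llama(llm_config)
graph_module = builder.edge_manager.exported_program().graph_module
```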
examples/models/llava/export_llava.py

Lines changed: 1 addition & 9 deletions

```diff
@@ -17,11 +17,7 @@
     XNNPACKQuantizer,
 )
 from executorch.examples.models.llama.config.llm_config import LlmConfig
-from executorch.examples.models.llama.config.llm_config_utils import (
-    convert_args_to_llm_config,
-)
 from executorch.examples.models.llama.export_llama_lib import (
-    build_args_parser,
     get_quantizer_and_quant_params,
 )
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -96,12 +92,8 @@ def forward(self, input_pos, embeddings):
         dynamic_shapes=dynamic_shapes,
     )
 
-    # (Legacy) parse args then convert to LlmConfig.
-    parser = build_args_parser()
-    args = parser.parse_args()
-    llm_config = convert_args_to_llm_config(args)
-
     # Manually set some LlmConfig options.
+    llm_config = LlmConfig()
     llm_config.base.params = "params.json"
     llm_config.backend.xnnpack.enabled = True
     llm_config.quantization.qmode = "8da4w"
```
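The llava exporter now builds its config programmatically instead of round-tripping through `build_args_parser` and `convert_args_to_llm_config`. A sketch of the resulting pattern, restricted to the fields this diff actually sets:

```python
from executorch.examples.models.llama.config.llm_config import LlmConfig

# Construct the config directly; no argparse round-trip.
llm_config = LlmConfig()
llm_config.base.params = "params.json"      # model hyperparameters file
llm_config.backend.xnnpack.enabled = True   # delegate to the XNNPACK backend
llm_config.quantization.qmode = "8da4w"     # 8-bit dynamic activations, 4-bit weights
```

One consequence of dropping the parser is that `export_llava.py` no longer picks up CLI overrides; every option this script needs must now be set on `llm_config` explicitly.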
