Commit d2d79f5
refactor: simplify Llama2Model constructor to take llm_config directly
ghstack-source-id: 2c66b9f
Pull Request resolved: #11170

Parent: 14df400

2 files changed (+7, -8)
examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1226,7 +1226,7 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            llm_config=llm_config,
+            model_args={"llm_config": llm_config},
         )
     )
 
```

examples/models/llama/model.py

Lines changed: 6 additions & 7 deletions

```diff
@@ -36,19 +36,18 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, **kwargs):
+    def __init__(self, llm_config):
         resource_dir = get_default_model_resource_dir(__file__)
 
+        self.llm_config = llm_config
+
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", None)
+        checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        checkpoint_dir = self.llm_config.base.checkpoint_dir
 
         # Params file.
-        params_path = kwargs.get("params", None)
-
-        self.llm_config = kwargs.get("llm_config")
-        assert self.llm_config is not None, "llm_config must be provided"
+        params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
```
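With this change the constructor reads the checkpoint, params, and model options from one config object rather than scattered kwargs, so the `assert llm_config is not None` guard becomes unnecessary. A minimal usage sketch, assuming an `LlmConfig` with `base` and `model` sub-configs as read in the diff above (the dataclasses below are illustrative stand-ins, not the real executorch config classes):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseConfig:
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None
    params: Optional[str] = None


@dataclass
class ModelConfig:
    use_kv_cache: bool = False
    use_sdpa_with_kv_cache: bool = False


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)


# Before: Llama2Model(checkpoint="llama2.pt", params="params.json", llm_config=cfg)
# After: one required argument; the config is the single source of truth.
cfg = LlmConfig(base=BaseConfig(checkpoint="llama2.pt", params="params.json"))
# model = Llama2Model(cfg)  # the constructor reads everything from cfg
```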
