2 files changed: +7 −8 lines

@@ -1226,7 +1226,7 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            llm_config=llm_config,
+            model_args={"llm_config": llm_config},
         )
     )
 
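This hunk stops passing llm_config as a direct keyword argument and instead wraps it in a model_args dict. A minimal sketch of how a factory along these lines could unpack model_args into the model constructor; this is an illustrative assumption, not the actual EagerModelFactory implementation:

# Illustrative sketch only, not the real EagerModelFactory. It shows how a
# model_args dict can be unpacked into the model class's constructor, which is
# why a Llama2Model.__init__(self, llm_config) signature still receives
# llm_config as a keyword argument.
import importlib


class EagerModelFactory:
    @staticmethod
    def create_model(module_name, model_class_name, model_args=None):
        # Resolve the model class by name, then unpack model_args, so
        # model_args={"llm_config": llm_config} becomes llm_config=llm_config.
        module = importlib.import_module(module_name)
        model_class = getattr(module, model_class_name)
        return model_class(**(model_args or {}))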
@@ -36,19 +36,18 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, **kwargs):
+    def __init__(self, llm_config):
         resource_dir = get_default_model_resource_dir(__file__)
 
+        self.llm_config = llm_config
+
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", None)
+        checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        checkpoint_dir = self.llm_config.base.checkpoint_dir
 
         # Params file.
-        params_path = kwargs.get("params", None)
-
-        self.llm_config = kwargs.get("llm_config")
-        assert self.llm_config is not None, "llm_config must be provided"
+        params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
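This hunk replaces untyped kwargs.get(...) lookups and the runtime assert with attribute access on a required llm_config, so a missing config now fails at the call site rather than deep in __init__. The diff implies llm_config exposes at least base.checkpoint, base.checkpoint_dir, base.params, model.use_kv_cache, and model.use_sdpa_with_kv_cache. A hypothetical dataclass sketch of that shape, with field names taken from the diff and the layout otherwise assumed:

# Hypothetical sketch of the llm_config shape this constructor relies on.
# Field names come from the diff above; the dataclass layout and defaults
# are assumptions for illustration.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseConfig:
    checkpoint: Optional[str] = None       # single checkpoint file
    checkpoint_dir: Optional[str] = None   # directory of sharded checkpoints
    params: Optional[str] = None           # params file


@dataclass
class ModelConfig:
    use_kv_cache: bool = False
    use_sdpa_with_kv_cache: bool = False


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)


# With this shape, construction is explicit instead of kwargs-based:
# model = Llama2Model(LlmConfig(base=BaseConfig(checkpoint="llama2.pt")))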