 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple

 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)

 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

 try:
     from .fairseq2 import convert_to_llama_checkpoint

 except ImportError:
-
     def convert_to_llama_checkpoint(**kwargs):
         raise NotImplementedError(
             "Please install fairseq2 with `pip install fairseq2`."
@@ -30,48 +33,29 @@ def convert_to_llama_checkpoint(**kwargs):

 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir("llama2")

         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get("checkpoint", resource_dir / "demo_rand_params.pth")
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")

-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)

         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
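
Note: the hunk above replaces the inline pkg_resources / Path fallback with a single call to get_default_model_resource_dir("llama2") from executorch.examples.models.checkpoint. Below is a minimal sketch of what such a helper could look like, reconstructed from the removed lines; the actual implementation in the checkpoint module may differ, and the package path and fallback directory layout used here are assumptions.

from pathlib import Path


def get_default_model_resource_dir(model_name: str) -> Path:
    """Sketch only: resolve the bundled `params` resource directory for an example model.

    Mirrors the inline logic removed above: prefer pkg_resources (buck2 builds),
    fall back to the source-tree layout.
    """
    try:
        # Under buck2, resources can be accessed with pkg_resources
        # (assumed package path).
        import pkg_resources

        return Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # Assumed fallback: examples/models/<model_name>/params in the source tree.
        return Path(__file__).absolute().parent / model_name / "params"
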
@@ -98,8 +82,11 @@ def __init__(self, **kwargs):
                 else:
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +95,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]

+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,44 +112,28 @@ def __init__(self, **kwargs):
 """
             )

-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +205,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")

-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab

             self.model_ = prune_output_vocab(self.model_, output_prune_map)

-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged