 
 import json
 import os
-from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
 
@@ -30,48 +34,29 @@ def convert_to_llama_checkpoint(**kwargs):
 
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        ckpt_dir = get_default_model_resource_dir()
 
         # Use single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
-
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
 
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.args = kwargs.get("args", None)
+
         # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
         cps = []
+        # Load sharded checkpoint.
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -98,8 +83,11 @@ def __init__(self, **kwargs):
                 else:
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
         fairseq2_checkpoint = kwargs.get("fairseq2", False)
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
@@ -108,12 +96,12 @@ def __init__(self, **kwargs):
             # NB: some checkpoint contains a "model" field, which is the actual weights dict
             checkpoint = checkpoint["model"]
 
+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
         if (not fairseq2_checkpoint) and checkpoint.get(
             "final_proj.weight", None
         ) is not None:
-            print(
+            raise ValueError(
                 """
-
 ************************************************************
 This looks like a Fairseq2 checkpoint (based on the presence
 of `final_proj.weight`.
@@ -125,44 +113,28 @@ def __init__(self, **kwargs):
 """
             )
 
-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
             generate_full_logits=self.generate_full_logits,
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
@@ -234,13 +206,13 @@ def __init__(self, **kwargs):
             print(unexpected)
             print("============= /unexpected ================")
 
-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
         if output_prune_map is not None:
             from .source_transformation.prune_output import prune_output_vocab
 
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
             # convert to the type of the provided checkpoint
             # input and output are torch.long, so signature unchanged
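
Note: the two helpers imported above from executorch.examples.models.checkpoint are not shown in this diff. Below is a minimal sketch of what they might look like, reconstructed from the inline logic this diff removes; the bodies are assumptions for illustration, and the actual implementations in executorch/examples/models/checkpoint.py may differ in signature and error handling.

# Hypothetical sketch of the extracted helpers, based on the removed inline code.
from pathlib import Path
from typing import Dict, Optional

import torch


def get_default_model_resource_dir() -> Path:
    """Resolve the params/ resource directory (assumed behavior).

    Mirrors the removed try/except: prefer pkg_resources when running under
    buck2, otherwise fall back to a directory next to the module.
    """
    try:
        import pkg_resources

        return Path(
            pkg_resources.resource_filename(
                "executorch.examples.models.llama2", "params"
            )
        )
    except Exception:
        # NOTE: in the original inline code this resolved relative to the
        # llama2 model file; here it resolves relative to this module.
        return Path(__file__).absolute().parent / "params"


def get_checkpoint_dtype(checkpoint: Dict[str, torch.Tensor]) -> Optional[torch.dtype]:
    """Return the dtype of the first tensor, warning on mixed dtypes (assumed behavior)."""
    if len(checkpoint) == 0:
        return None
    first_key = next(iter(checkpoint))
    dtype = checkpoint[first_key].dtype
    mismatched = [(k, v.dtype) for k, v in checkpoint.items() if v.dtype != dtype]
    if mismatched:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {dtype}. "
            f"Mismatches in the checkpoint: {mismatched}"
        )
    return dtype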