@@ -17,6 +17,7 @@
 )
 
 from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules.attention import replace_mha_with_inference_mha
 from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
 
@@ -53,7 +54,7 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
-        self.dtype = None
+        self.dtype = kwargs.get("dtype", torch.float16)
         self.use_checkpoint = False
 
         ckpt_dir = get_default_model_resource_dir(__file__)
@@ -72,7 +73,7 @@ def __init__(self, **kwargs):
                 dtype=torch.bool,
             )
         )
-        self.input_pos = torch.arange(self.max_seq_len)
+        self.input_pos = torch.arange(self.max_seq_len, dtype=torch.int64)
 
         # Load checkpoint and params.
         device = "cpu"
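For context, the prefill-style slicing that these two buffers support later in the diff works roughly as below. This is an illustrative sketch only, not part of the change; max_seq_len and n_tokens are made-up values.

import torch

# Illustrative only: a prefill slice of the causal mask and input_pos
# for an n_tokens-long prompt.
max_seq_len = 8
n_tokens = 3
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(max_seq_len, dtype=torch.int64)

prefill_mask = causal_mask[None, :n_tokens]  # shape (1, 3, 8)
prefill_pos = input_pos[None, :n_tokens]     # shape (1, 3)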
@@ -107,6 +108,9 @@ def __init__(self, **kwargs):
             rope_base=params["rope_theta"],
             intermediate_dim=params["intermediate_dim"],
         )
+
+        # Source transformation for MultiHeadAttention.
+        self.model_ = replace_mha_with_inference_mha(self.model_)
         # Save params for future use.
         for param_name, param_val in params.items():
             setattr(self.model_, param_name, param_val)
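The replace_mha_with_inference_mha call above is a source transformation: it rewrites the torchtune attention submodules into export-friendly equivalents before export. As a rough sketch of the general pattern only (the swap_modules helper below is hypothetical and far simpler than the executorch implementation):

import torch

def swap_modules(module: torch.nn.Module, old_cls, new_from_old) -> torch.nn.Module:
    # Recursively replace every instance of `old_cls` with the module
    # returned by `new_from_old`. Sketch of a source-transformation pass.
    for name, child in module.named_children():
        if isinstance(child, old_cls):
            setattr(module, name, new_from_old(child))
        else:
            swap_modules(child, old_cls, new_from_old)
    return module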
@@ -147,39 +151,46 @@ def __init__(self, **kwargs):
             self.model_.setup_caches(
                 batch_size=1,
                 dtype=self.dtype,
+                encoder_max_seq_len=self.encoder_max_seq_len,
                 decoder_max_seq_len=self.max_seq_len,
             )
+        # Number of tokens for the example input.
+        self.n_tokens = 34
+        self.model_.to(self.dtype)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            return self.model_.to(self.dtype)
-        else:
-            return self.model_.to(torch.float16)
+        return self.model_
 
     def get_example_inputs(self):
-        return (torch.ones(1, 32, dtype=torch.long),)
+        return (torch.ones(1, self.n_tokens, dtype=torch.int64),)
 
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+        # Hardcoding the number of tiles to be 2; image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, :32],
-                "mask": self.causal_mask[None, :32],
-                # "encoder_input": None,
-                # "encoder_mask": None,
+                "input_pos": self.input_pos[None, : self.n_tokens],
+                "mask": self.causal_mask[None, : self.n_tokens],
+                "encoder_input": torch.randn(
+                    1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
+                ),
+                "encoder_mask": torch.ones(
+                    [1, self.n_tokens, self.encoder_max_seq_len], dtype=torch.bool
+                ),
             }
         else:
             return None
 
     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
+        # Hardcoding the number of tiles to be 2; image tokens per tile is 1601.
         if self.use_kv_cache:
             dynamic_shapes = {
                 "tokens": {0: batch_size, 1: dim_seq_len},
-                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "encoder_input": None,
+                "encoder_mask": {0: 1, 1: dim_seq_len, 2: None},
                 "mask": {0: batch_size, 1: dim_seq_len, 2: None},
                 "input_pos": {0: batch_size, 1: dim_seq_len},
             }
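With these changes, the example inputs, kwargs, and dynamic shapes line up for torch.export. A minimal usage sketch, assuming `model` is an instance of the EagerModelBase subclass modified above (construction arguments elided):

import torch

# Hypothetical usage sketch; `model` is the patched model wrapper.
eager = model.get_eager_model()
exported = torch.export.export(
    eager,
    args=model.get_example_inputs(),
    kwargs=model.get_example_kwarg_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)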