@@ -21,6 +21,7 @@
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
+
 from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image
@@ -48,6 +49,7 @@ def __init__(
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.model_ = llava_model
         self.image_processor = image_processor
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `config`.
         self.vision_feature_layer = self.model_.config.vision_feature_layer
         self.vision_feature_select_strategy = (
             self.model_.config.vision_feature_select_strategy
@@ -76,6 +78,7 @@ def __init__(
         )
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         state_dict = self.model_.language_model.state_dict()
         key_map = {
             # fmt: off
@@ -128,9 +131,11 @@ def get_model(self):
         return self.model_.get_model()
 
     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
         return self.model_.language_model.model.embed_tokens(tokens)
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
         images = images.to(dtype=self.model_.dtype)
         if type(images) is list:
             image_features = []
@@ -144,15 +149,19 @@ def encode_images(self, images: torch.Tensor) -> torch.Tensor:
                 image_feature = self._feature_select(image_forward_out).to(image.dtype)
                 image_features.append(image_feature)
         else:
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `vision_tower`.
             image_forward_outs = self.model_.vision_tower(
+                # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `device`.
                 images.to(device=self.model_.device, dtype=self.model_.dtype),
                 output_hidden_states=True,
             )
             image_features = self._feature_select(image_forward_outs).to(images.dtype)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `multi_modal_projector`.
         image_features = self.model_.multi_modal_projector(image_features)
         return image_features
 
     def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `crop_size`.
         target_h = self.image_processor.crop_size["height"]
         target_w = self.image_processor.crop_size["width"]
         # pad the image with median rgb value, to make a square
@@ -195,10 +204,15 @@ def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
         # print(resized.shape)
         # cropped = F.center_crop(img, output_size=[w, w])
         # print(cropped.shape)
+        # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `rescale_factor`.
         scaled = resized * self.image_processor.rescale_factor
         # print(scaled)
         normed = F.normalize(
-            scaled, self.image_processor.image_mean, self.image_processor.image_std
+            scaled,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_mean`.
+            self.image_processor.image_mean,
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_vision_objects.CLIPImageProcessor` has no attribute `image_std`.
+            self.image_processor.image_std,
         )
         # print(normed)
         return normed.unsqueeze(0)
@@ -249,7 +263,9 @@ def prefill_ref(
     ) -> torch.Tensor:
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
+        # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
         return LlamaForCausalLM.forward(
+            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
             self.model_.language_model,
             inputs_embeds=embeds,
             return_dict=False,
@@ -268,12 +284,16 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+        )
         self.tokenizer = self.processor.tokenizer
         self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
         self.image = Image.open(
             requests.get(
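
For reference, the revision-pinning pattern introduced in the last hunk can be exercised on its own. The following is a minimal sketch, assuming only the `llava-hf/llava-1.5-7b-hf` checkpoint and the commit hash shown in the diff; it is not part of the change itself.

    from transformers import AutoProcessor, LlavaForConditionalGeneration

    # Commit hash taken from the diff above; pinning it keeps the processor and
    # model configs loadable with transformers >= 4.44.2.
    REVISION = "a272c74b2481d8aff3aa6fc2c4bf891fe57334fb"

    processor = AutoProcessor.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        revision=REVISION,
    )
    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",
        device_map="cpu",
        revision=REVISION,
    )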