Skip to content

Commit 79c15ef

Browse files
authored
Fix llava model definition for export
Differential Revision: D61044795 Pull Request resolved: #4650
1 parent c5a816e commit 79c15ef

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

examples/models/llava/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ def __init__(
9999
assign=True,
100100
)
101101
self.image_processor = image_processor
102-
self.vision_tower = self.get_model().vision_tower
103-
self.mm_projector = self.get_model().mm_projector
104102

105103
def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
106104
state_dict = self.model_.state_dict()
@@ -143,8 +141,8 @@ def get_model(self):
143141

144142
def encode_images(self, images: torch.Tensor) -> torch.Tensor:
145143
images = images.to(dtype=self.get_model().dtype)
146-
image_features = self.vision_tower(images)
147-
image_features = self.mm_projector(image_features)
144+
image_features = self.get_model().vision_tower(images)
145+
image_features = self.get_model().mm_projector(image_features)
148146
return image_features
149147

150148
def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)