Skip to content

Commit f3b25e4

Browse files
authored
multimodal : add BakLLaVA conversion support (#3682)
1 parent 60abea9 commit f3b25e4

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

examples/llava/llava-surgery.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,29 @@
1616
# Fragment of examples/llava/llava-surgery.py (post-commit version of the hunk).
# NOTE(review): `checkpoint` (a loaded state dict), `args` (parsed CLI args with
# a `.model` directory path), `path` (output checkpoint path), `torch`, and `os`
# are all defined earlier in the file, outside this hunk — do not redefine here.

# Collect the multimodal projector tensors from the LLaVA checkpoint.
# (Iterate keys directly — the values are not needed for the filter.)
mm_tensors = [k for k in checkpoint if k.startswith("model.mm_projector")]

# Store these tensors in a new dictionary and torch.save them.
# .float() upcasts (e.g. from fp16) so the saved projector has a uniform dtype.
projector = {name: checkpoint[name].float() for name in mm_tensors}
torch.save(projector, f"{args.model}/llava.projector")

# Remove the projector tensors from the checkpoint so it can be saved again
# as a plain language-model checkpoint.
for name in mm_tensors:
    del checkpoint[name]

# BakLLaVA checkpoints additionally embed the CLIP vision-tower tensors.
clip_tensors = [k for k in checkpoint if k.startswith("model.vision_tower")]
if clip_tensors:
    # Strip the nested "vision_tower.vision_tower." prefix so the saved keys
    # match what the downstream CLIP conversion expects.
    clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors}
    torch.save(clip, f"{args.model}/llava.clip")

    # Remove the CLIP tensors from the checkpoint as well.
    for name in clip_tensors:
        del checkpoint[name]

    # Added tokens must be cleared to be able to convert Mistral-based models.
    # Hoist the path: it is used for both the existence check and the write.
    added_tokens_path = f"{args.model}/added_tokens.json"
    if os.path.exists(added_tokens_path):
        with open(added_tokens_path, "w") as f:
            f.write("{}\n")

# Save the stripped (text-only) checkpoint back to its original location.
torch.save(checkpoint, path)

print("Done!")

0 commit comments

Comments
 (0)