Skip to content

Commit 834775b

Browse files
committed
updates
1 parent f859002 commit 834775b

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -685,49 +685,66 @@ class Blip2Qformer(Blip2Base):
685685
...
686686
def generate(
    self,
    samples,  # dict with key "image": batch of preprocessed images to caption
    use_nucleus_sampling=False,  # True -> nucleus (top-p) sampling; False -> beam search
    num_beams=3,  # number of beams for beam search (forced to 1 when sampling)
    max_length=30,  # maximum generated caption length in tokens
    min_length=10,  # minimum generated caption length in tokens
    top_p=0.9,  # cumulative-probability threshold for nucleus sampling
    repetition_penalty=1.0,  # >1.0 discourages token repetition (1.0 = no penalty)
):
    """Generate one caption string per image in ``samples["image"]``.

    Returns:
        list[str]: decoded captions with special tokens stripped.
    """
    # 1. Encode images with the vision encoder, then layer-norm the features.
    image = samples["image"]
    image_embeds = self.ln_vision(self.visual_encoder(image))

    # 2. Beam search decodes num_beams hypotheses per image, so tile the
    #    image features along the batch dim: (B, ...) -> (B * num_beams, ...).
    #    Nucleus sampling keeps one row per image.
    if not use_nucleus_sampling:
        image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
    else:
        num_beams = 1

    # Every visual token is valid, so the cross-attention mask is all ones.
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
        image.device
    )

    # 3. The image features/mask feed the Q-Former's cross-attention layers.
    model_kwargs = {
        "encoder_hidden_states": image_embeds,
        "encoder_attention_mask": image_atts,
    }

    # 4. Seed decoding with a single [BOS] token per image: shape (B, 1).
    input_ids = (
        torch.LongTensor(image.size(0), 1)
        .fill_(self.tokenizer.bos_token_id)
        .to(image.device)
    )

    # 5. Broadcast the learned query tokens across the (possibly tiled) batch.
    query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

    # 6. Autoregressive generation through the Q-Former.
    # FIX: repetition_penalty was accepted by this method but never forwarded,
    # so caller-supplied values were silently ignored. Forward it now; the
    # default of 1.0 means no penalty, preserving the original behavior.
    outputs = self.Qformer.generate(
        input_ids=input_ids,
        query_embeds=query_tokens,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        do_sample=use_nucleus_sampling,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        eos_token_id=self.tokenizer.sep_token_id,  # BERT-style tokenizer: [SEP] doubles as EOS
        pad_token_id=self.tokenizer.pad_token_id,
        **model_kwargs,
    )

    # 7. Decode token ids to plain caption strings, dropping special tokens.
    captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return captions
733750
```

0 commit comments

Comments
 (0)