@@ -685,49 +685,66 @@ class Blip2Qformer(Blip2Base):
    ...
    def generate(
        self,
-       samples,
-       use_nucleus_sampling=False,
-       num_beams=3,
-       max_length=30,
-       min_length=10,
-       top_p=0.9,
-       repetition_penalty=1.0,
+       samples,  # input samples containing the image and optional text
+       use_nucleus_sampling=False,  # whether to use nucleus (top-p) sampling
+       num_beams=3,  # number of beams for beam search
+       max_length=30,  # maximum length of the generated text
+       min_length=10,  # minimum length of the generated text
+       top_p=0.9,  # probability threshold for nucleus sampling
+       repetition_penalty=1.0,  # repetition penalty factor
    ):
+       # 1. Image encoding stage
        image = samples["image"]
-       image_embeds = self.ln_vision(self.visual_encoder(image))
+       # Extract image features with the visual encoder (e.g. a ViT): (B, 257, D)
+       image_embeds = self.ln_vision(self.visual_encoder(image))

+       # 2. Handle beam-search expansion
        if not use_nucleus_sampling:
+           # For beam search, replicate the image features to match the beam count:
+           # (B, 257, D) -> (B*num_beams, 257, D)
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
        else:
+           # Nucleus sampling does not expand beams
            num_beams = 1
+
+       # Create the image attention mask (all ones: every image token is valid)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

+       # 3. Prepare the generation arguments
        model_kwargs = {
-           "encoder_hidden_states": image_embeds,
-           "encoder_attention_mask": image_atts,
+           "encoder_hidden_states": image_embeds,  # image features for cross-attention
+           "encoder_attention_mask": image_atts,  # image attention mask
        }

+       # 4. Initialize the text input (starting from the BOS token)
+       # Shape: (batch_size, 1), initialized to [BOS]
        input_ids = (
            torch.LongTensor(image.size(0), 1)
            .fill_(self.tokenizer.bos_token_id)
            .to(image.device)
        )
+
+       # 5. Expand the learnable query tokens
+       # query_tokens shape: (batch_size, num_query_tokens, D)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

+       # 6. Call the Q-Former's generate method
        outputs = self.Qformer.generate(
-           input_ids=input_ids,
-           query_embeds=query_tokens,
-           max_length=max_length,
-           min_length=min_length,
-           num_beams=num_beams,
-           do_sample=use_nucleus_sampling,
-           top_p=top_p,
-           eos_token_id=self.tokenizer.sep_token_id,
-           pad_token_id=self.tokenizer.pad_token_id,
-           **model_kwargs
+           input_ids=input_ids,  # initial text token [BOS]
+           query_embeds=query_tokens,  # learnable query tokens
+           max_length=max_length,  # maximum generation length
+           min_length=min_length,  # minimum generation length
+           num_beams=num_beams,  # number of beams
+           do_sample=use_nucleus_sampling,  # whether to sample
+           top_p=top_p,  # nucleus sampling parameter
+           eos_token_id=self.tokenizer.sep_token_id,  # end-of-sequence token
+           pad_token_id=self.tokenizer.pad_token_id,  # padding token
+           **model_kwargs  # image features and mask
        )
+
+       # 7. Decode the generated token ids into text
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
```
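A note on step 2: `repeat_interleave(num_beams, dim=0)` lays the copied encoder states out sample-by-sample, so all beams of one image are adjacent, which is the layout HuggingFace-style beam search expects. The toy sketch below, using made-up tensors, contrasts it with plain `repeat`:

```python
# Self-contained toy example of the beam-expansion layout used above.
import torch

num_beams = 2
# (B=2, 1, D=1): image features for two images, a=1.0 and b=2.0
image_embeds = torch.tensor([[1.0], [2.0]]).unsqueeze(1)

interleaved = image_embeds.repeat_interleave(num_beams, dim=0)  # [a, a, b, b]
tiled = image_embeds.repeat(num_beams, 1, 1)                    # [a, b, a, b]

print(interleaved.squeeze())  # tensor([1., 1., 2., 2.])  <- beams stay adjacent
print(tiled.squeeze())        # tensor([1., 2., 1., 2.])  <- beams scattered
```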
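For context, here is a minimal usage sketch of the method above. It assumes `model` is an already-instantiated `Blip2Qformer` with pretrained weights in eval mode; the file name `example.jpg` and the CLIP-style normalization constants are illustrative stand-ins, not LAVIS's exact preprocessing pipeline.

```python
# Minimal usage sketch. ASSUMPTIONS: `model` is an already-loaded Blip2Qformer
# on `device`; "example.jpg" and the CLIP-style normalization constants are
# illustrative, not LAVIS's exact eval transform.
import torch
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # match the ViT input resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])
image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    # Default decoding path: beam search with num_beams=3
    beam_captions = model.generate({"image": image}, num_beams=3, max_length=30)
    # Stochastic alternative: nucleus sampling (num_beams is forced to 1 inside)
    sampled_captions = model.generate({"image": image},
                                      use_nucleus_sampling=True, top_p=0.9)

print(beam_captions, sampled_captions)
```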