            return_dict=True,
        )
        ...
        ##================= Image Captioning ========================##
        # Goal of this part: use the Q-Former decoder to generate a text description (caption) from the image features

        # Step 1: prepare the decoder input token IDs
        decoder_input_ids = text_tokens.input_ids.clone()
        # Replace the first token with the BOS (Begin Of Sentence) token, i.e. "start generating the sentence"
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id

        # Step 2: build the training targets (labels)
        # Replace padding tokens with -100, the label value that CrossEntropyLoss ignores by default
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        # Step 3: build the attention_mask (covering both the query tokens and the text tokens)
        # query_atts is the attention mask for the query tokens, all ones (every query token is valid)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
        # Concatenate the query-token mask with the text-token mask
        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)

        # Step 4: run the Q-Former decoder to generate the caption
        lm_output = self.Qformer(
            decoder_input_ids,                             # input token ID sequence (e.g. [BOS], a, dog, ...)
            attention_mask=attention_mask,                 # marks which positions are valid (non-padding)
            past_key_values=query_output.past_key_values,  # cached per-layer keys/values from the encoding pass, carrying the image information
            return_dict=True,                              # return the result as a dict
            labels=labels,                                 # training targets used to compute the loss
        )

        # Step 5: read out the language-modeling loss
        loss_lm = lm_output.loss  # cross-entropy loss measuring the gap between generated and ground-truth captions
```
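To see why padding is mapped to `-100` in Step 2, here is a minimal, self-contained sketch. The token IDs (`bos_id`, `pad_id`) and the random logits are hypothetical stand-ins, not values from the real BLIP-2 tokenizer or model; the point is only that `CrossEntropyLoss` ignores label `-100` by default, so padded positions contribute nothing to the caption loss.

```python
import torch
from torch.nn import CrossEntropyLoss

bos_id, pad_id, vocab_size = 1, 0, 10          # hypothetical special-token IDs and toy vocabulary
input_ids = torch.tensor([[2, 5, 7, pad_id]])  # one toy caption; the last position is padding

decoder_input_ids = input_ids.clone()
decoder_input_ids[:, 0] = bos_id               # first token becomes BOS, as in Step 1
labels = decoder_input_ids.masked_fill(decoder_input_ids == pad_id, -100)  # Step 2

logits = torch.randn(1, 4, vocab_size)         # fake decoder outputs, one logit vector per position
loss_fct = CrossEntropyLoss()                  # ignore_index defaults to -100
loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))

print(labels)  # tensor([[   1,    5,    7, -100]]) -> the padded position is excluded from the loss
```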
5. BertLMHeadModel: autoregressive language modeling (e.g., text generation)

```python
class BertLMHeadModel(BertPreTrainedModel):
    ...
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        ...
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        ...
        # self.cls is a classification head (BertOnlyMLMHead) that maps each token's hidden state to vocabulary logits
        prediction_scores = self.cls(sequence_output)
        ...
        lm_loss = None
        if labels is not None:
            # We are predicting the next token, so logits and labels are shifted by one position relative to each other:
            # shifted_prediction_scores: predictions at every position except the last
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            # labels: the ground-truth tokens starting from the second position
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
        ...
        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
```
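The shift-by-one in the loss above is the standard next-token objective: the logits at position t are scored against the ground-truth token at position t+1. A minimal sketch of what the `reduction="none"` branch computes, using random tensors with toy shapes (batch 2, sequence length 5, vocabulary 10) in place of real model outputs:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 5, 10
prediction_scores = torch.randn(batch, seq_len, vocab_size)  # stand-in for self.cls(sequence_output)
labels = torch.randint(0, vocab_size, (batch, seq_len))      # stand-in for the ground-truth token IDs

shifted_scores = prediction_scores[:, :-1, :].contiguous()   # drop the prediction at the last position
shifted_labels = labels[:, 1:].contiguous()                   # drop the first target (nothing predicts it)

loss_fct = CrossEntropyLoss(reduction="none", label_smoothing=0.1)
lm_loss = loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
lm_loss = lm_loss.view(batch, -1).sum(1)                      # one summed loss per sample, as in the "none" branch

print(lm_loss.shape)  # torch.Size([2])
```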