Skip to content

Commit 89dc282

Browse files
committed
updates
1 parent 55917af commit 89dc282

File tree

1 file changed

+99
-21
lines changed

1 file changed

+99
-21
lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -550,27 +550,105 @@ class BertEncoder(nn.Module):
550550
return_dict=True,
551551
)
552552
...
553-
##================= Image Captioning ========================##
554-
decoder_input_ids = text_tokens.input_ids.clone()
555-
# 将第一个 token 替换为 BOS(Begin Of Sentence)标记,表示“开始生成句子”
556-
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
557-
# 将 padding token 替换为 -100,这是 CrossEntropyLoss 默认忽略的标签值
558-
labels = decoder_input_ids.masked_fill(
559-
decoder_input_ids == self.tokenizer.pad_token_id, -100
560-
)
561-
562-
563-
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
564-
image.device
565-
)
566-
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
567-
lm_output = self.Qformer(
568-
decoder_input_ids,
553+
##================= Image Captioning ========================##
554+
# 这一部分的目标是:根据图像特征,使用 Q-Former 解码器生成文本描述(caption)
555+
556+
# Step 1: 准备 decoder 的输入 token IDs
557+
decoder_input_ids = text_tokens.input_ids.clone()
558+
# 将第一个 token 替换为 BOS(Begin Of Sentence)标记,表示“开始生成句子”
559+
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
560+
561+
# Step 2: 构造训练目标 labels
562+
# 将 padding token 替换为 -100,这是 CrossEntropyLoss 默认忽略的标签值
563+
labels = decoder_input_ids.masked_fill(
564+
decoder_input_ids == self.tokenizer.pad_token_id, -100
565+
)
566+
567+
# Step 3: 构建 attention_mask(包含 query tokens 和 文本 token 的 mask)
568+
# query_atts 是 query tokens 的 attention mask,全为 1(因为都是有效 token)
569+
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
570+
# 将 query token 的 mask 和文本 token 的 mask 拼接在一起
571+
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
572+
573+
# Step 4: 调用 Q-Former 解码器进行文本生成
574+
lm_output = self.Qformer(
575+
decoder_input_ids, # 输入 token ID 序列(如 [BOS], dog, is...)
576+
attention_mask=attention_mask, # 指明哪些位置是有效的(非 padding)
577+
past_key_values=query_output.past_key_values, # 编码器输出的 key/value,包含图像信息
578+
return_dict=True, # 返回字典格式结果
579+
labels=labels, # 训练目标,用于计算 loss
580+
)
581+
582+
# Step 5: 提取语言模型损失
583+
loss_lm = lm_output.loss # 使用交叉熵损失衡量生成与真实之间的差异
584+
```
585+
5. BertLMHeadModel: 自回归语言建模任务(如文本生成)
586+
587+
```python
588+
class BertLMHeadModel(BertPreTrainedModel):
589+
...
590+
def forward(
591+
self,
592+
input_ids=None,
593+
attention_mask=None,
594+
position_ids=None,
595+
head_mask=None,
596+
query_embeds=None,
597+
encoder_hidden_states=None,
598+
encoder_attention_mask=None,
599+
labels=None,
600+
past_key_values=None,
601+
use_cache=True,
602+
output_attentions=None,
603+
output_hidden_states=None,
604+
return_dict=None,
605+
return_logits=False,
606+
is_decoder=True,
607+
reduction="mean",
608+
):
609+
...
610+
outputs = self.bert(
611+
input_ids,
569612
attention_mask=attention_mask,
570-
past_key_values=query_output.past_key_values, # 传入缓存的每一层key&value
571-
return_dict=True,
572-
labels=labels,
613+
position_ids=position_ids,
614+
head_mask=head_mask,
615+
query_embeds=query_embeds,
616+
encoder_hidden_states=encoder_hidden_states,
617+
encoder_attention_mask=encoder_attention_mask,
618+
past_key_values=past_key_values,
619+
use_cache=use_cache,
620+
output_attentions=output_attentions,
621+
output_hidden_states=output_hidden_states,
622+
return_dict=return_dict,
623+
is_decoder=is_decoder,
573624
)
574625

575-
loss_lm = lm_output.loss
576-
```
626+
sequence_output = outputs[0]
627+
...
628+
# self.cls 是一个分类头(BertOnlyMLMHead),它将每个 token 的向量映射到词汇表空间(logits)
629+
prediction_scores = self.cls(sequence_output)
630+
...
631+
lm_loss = None
632+
if labels is not None:
633+
# 因为我们要预测下一个 token,所以把 logits 和 labels 错位对齐:
634+
# shifted_prediction_scores: 所有 token 的预测(除了最后一个)
635+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
636+
# labels: 所有 token 的真实值(从第二个开始)
637+
labels = labels[:, 1:].contiguous()
638+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
639+
lm_loss = loss_fct(
640+
shifted_prediction_scores.view(-1, self.config.vocab_size),
641+
labels.view(-1),
642+
)
643+
if reduction == "none":
644+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
645+
...
646+
return CausalLMOutputWithCrossAttentions(
647+
loss=lm_loss,
648+
logits=prediction_scores,
649+
past_key_values=outputs.past_key_values,
650+
hidden_states=outputs.hidden_states,
651+
attentions=outputs.attentions,
652+
cross_attentions=outputs.cross_attentions,
653+
)
654+
```

0 commit comments

Comments
 (0)