Commit 29a1431

Merge pull request #12 from BinaryOracle/master
Master
2 parents 85706fc + 89dc282 commit 29a1431

File tree: 1 file changed (+295 −15)
src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 295 additions & 15 deletions
>
> ![Multimodal Causal Self-attention Mask](庖丁解牛BLIP2/10.png)

**Visual encoding stage**:

The image is encoded by a visual encoder (such as ViT) into image features `image_embeds`. The query tokens absorb these image features through cross-attention and then produce a compressed visual representation through self-attention. What gets cached is the `past_key_values` of the query tokens' self-attention (not the key/value of the cross-attention).
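
Before walking through the implementation, here is a minimal, hypothetical sketch of this cache-and-reuse idea (toy shapes and random tensors, not the actual Q-Former code): the query tokens' self-attention key/value are cached during visual encoding, and the text tokens' key/value are later appended to that cache.

```python
import torch

# Assumed toy dimensions for illustration only
B, H, D = 2, 12, 64          # batch, attention heads, head dim
num_query, text_len = 32, 8  # 32 learnable query tokens, 8 text tokens

# Visual encoding stage: cache the key/value produced by the query tokens' self-attention
cached_k = torch.randn(B, H, num_query, D)
cached_v = torch.randn(B, H, num_query, D)
past_key_value = (cached_k, cached_v)  # what the cache holds for one layer

# Text generation stage: compute key/value for the text tokens only ...
text_k = torch.randn(B, H, text_len, D)
text_v = torch.randn(B, H, text_len, D)

# ... then concatenate with the cache along the sequence dimension (dim=2),
# mirroring the `elif past_key_value is not None` branch shown below
key_layer = torch.cat([past_key_value[0], text_k], dim=2)
value_layer = torch.cat([past_key_value[1], text_v], dim=2)

print(key_layer.shape)  # torch.Size([2, 12, 40, 64]): text attends to cached query states + text
```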

QFormer uses `past_key_values` to cache and reuse the self-attention key/value inside each EncoderLayer:

1. BertSelfAttention: unifies the self-attention and cross-attention computation paths; each forward pass returns the key & value that may need to be cached

```python
class BertSelfAttention(nn.Module):
    ...
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # Determine whether this is cross-attention
        is_cross_attention = encoder_hidden_states is not None

        # Cross-attention: key and value both come from the image; the query comes from the query tokens
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        # If cached key/value are passed in, first compute key and value from the text embeddings,
        # then concatenate them with the cached key/value along the seq_len dimension
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)  # (Batch, Heads, Seq_len, Hidden_size)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            # Plain self-attention
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Cross-attention: with image features passed in, q comes from the query tokens
        # Self-attention: q comes from the query tokens or the text embeddings
        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # * Cache key and value
        past_key_value = (key_layer, value_layer)

        # Compute attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask
            attention_scores = attention_scores + attention_mask

        # Softmax normalization gives the attention probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout to prevent overfitting
        attention_probs_dropped = self.dropout(attention_probs)

        # Compute the context representation
        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        # The last element of outputs holds the cached key and value
        outputs = outputs + (past_key_value,)
        return outputs
```

2. BertLayer: orchestrates the self-attention and cross-attention computation flow

```python
class BertLayer(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query token padding mask
        head_mask=None,
        encoder_hidden_states=None,   # image tokens
        encoder_attention_mask=None,  # image padding mask
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        # Self-attention
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,  # cached key and value
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        # Cross-attention
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )
        ...
        outputs = (layer_output,) + outputs
        outputs = outputs + (present_key_value,)  # the last element of outputs holds the cached key and value
        return outputs
```

3. BertEncoder: organizes the computation across the stack of BertLayer modules (the layout of the per-layer cache it returns is sketched after the listing)

```python
class BertEncoder(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query tokens padding mask
        head_mask=None,
        encoder_hidden_states=None,   # image features
        encoder_attention_mask=None,  # image padding mask
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        ...
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            ...
            # If a cache exists, the current BertLayer fetches the key & value previously cached for this layer
            past_key_value = past_key_values[i] if past_key_values is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                query_length,
            )

            hidden_states = layer_outputs[0]
            # The key & value produced by every BertLayer are cached
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
        ...
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
```
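
For reference, a minimal sketch of the cache layout implied by the loop above (shapes are assumed toy values, not taken from the real model config): `past_key_values` is a tuple with one entry per layer, and each entry is the (key, value) pair cached by that layer's self-attention.

```python
import torch

# Assumed toy dimensions for illustration only
num_hidden_layers = 12
B, H, S, D = 2, 12, 32, 64  # batch, heads, cached sequence length, head dim

# What next_decoder_cache accumulates: one (key, value) pair per BertLayer
next_decoder_cache = tuple(
    (torch.randn(B, H, S, D), torch.randn(B, H, S, D))
    for _ in range(num_hidden_layers)
)

# On a later forward pass, layer i picks up exactly its own slice of the cache
i = 0
past_key_value = next_decoder_cache[i]
assert past_key_value[0].shape == (B, H, S, D)  # key
assert past_key_value[1].shape == (B, H, S, D)  # value
```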

4. Image-Grounded Text Generation training objective

```python
...
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    use_cache=True,  # cache key & value
    return_dict=True,
)
...
##================= Image Captioning ========================##
# Goal of this part: generate a text description (caption) from the image features with the Q-Former decoder

# Step 1: prepare the decoder input token IDs
decoder_input_ids = text_tokens.input_ids.clone()
# Replace the first token with the BOS (Begin Of Sentence) token, i.e. "start generating the sentence"
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id

# Step 2: build the training targets (labels)
# Replace padding tokens with -100, the label value that CrossEntropyLoss ignores by default
labels = decoder_input_ids.masked_fill(
    decoder_input_ids == self.tokenizer.pad_token_id, -100
)

# Step 3: build the attention_mask (covering both the query tokens and the text tokens)
# query_atts is the attention mask of the query tokens, all ones (they are all valid tokens)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
# Concatenate the query-token mask with the text-token mask
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)

# Step 4: call the Q-Former decoder to generate text
lm_output = self.Qformer(
    decoder_input_ids,                             # input token ID sequence (e.g. [BOS], dog, is ...)
    attention_mask=attention_mask,                 # marks which positions are valid (non-padding)
    past_key_values=query_output.past_key_values,  # key/value cached in the encoding stage, carrying the image information
    return_dict=True,                              # return results as a dict
    labels=labels,                                 # training targets, used to compute the loss
)

# Step 5: extract the language-modeling loss
loss_lm = lm_output.loss  # cross-entropy loss measuring the gap between the generated and ground-truth text
```

5. BertLMHeadModel: the autoregressive language-modeling task (e.g. text generation); a toy illustration of its shifted loss follows the listing

```python
class BertLMHeadModel(BertPreTrainedModel):
    ...
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        ...
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        ...
        # self.cls is a classification head (BertOnlyMLMHead) that maps each token vector into vocabulary space (logits)
        prediction_scores = self.cls(sequence_output)
        ...
        lm_loss = None
        if labels is not None:
            # We predict the next token, so shift logits and labels so they line up:
            # shifted_prediction_scores: predictions for all tokens except the last one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            # labels: ground-truth tokens starting from the second position
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
        ...
        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
```
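
As a sanity check on the shifted loss above, here is a small self-contained toy example (made-up vocabulary size, lengths, and label values, not taken from the model): the prediction at position t is scored against token t+1, and padding positions marked with -100 are ignored by CrossEntropyLoss.

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len, batch = 10, 5, 1

prediction_scores = torch.randn(batch, seq_len, vocab_size)  # logits for every position
labels = torch.tensor([[3, 7, 2, 0, -100]])                  # -100 marks padding, ignored by the loss

# Position t predicts token t+1, so drop the last prediction and the first label
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(reduction="mean", label_smoothing=0.1)
lm_loss = loss_fct(
    shifted_prediction_scores.view(-1, vocab_size),
    shifted_labels.view(-1),
)
print(lm_loss)  # scalar loss over the 3 non-ignored shifted positions
```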
