> ![Multimodal Causal Self-attention Mask](庖丁解牛BLIP2/10.png)

**Visual encoding stage**:

The image is encoded by a visual encoder (e.g., ViT) into image features image_embeds. The query tokens absorb these image features through cross-attention and then produce a compressed visual representation through self-attention. What gets cached is the past_key_values of the query tokens' self-attention (not the cross-attention key/value).
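
As a rough shape sketch of this stage (the sizes below are assumptions based on BLIP-2's defaults, a ViT-g encoder producing 257 tokens of width 1408 and 32 learned query tokens matching the Q-Former's hidden size of 768; they only serve to make the compression concrete):

```python
import torch

batch = 2
# Frozen ViT output: one feature per patch plus CLS; width 1408 is the assumed ViT-g size
image_embeds = torch.randn(batch, 257, 1408)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

# 32 learned query tokens of width 768, broadcast to the batch
query_tokens = torch.randn(1, 32, 768).expand(batch, -1, -1)

# After the Q-Former encoding pass, the 257 image tokens are compressed into
# 32 query outputs: query_output.last_hidden_state has shape (batch, 32, 768)
```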

The Q-Former uses past_key_values to cache and reuse the self-attention key/value of each EncoderLayer:

1. BertSelfAttention: unifies the self-attention and cross-attention code paths, and after each computation returns the key & value that may need to be cached.

```python
class BertSelfAttention(nn.Module):
    ...
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # Decide whether this is a cross-attention call
        is_cross_attention = encoder_hidden_states is not None

        # Cross-attention: key and value come from the image features,
        # while the query comes from the query tokens
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        # If cached key/value are passed in, first compute key and value from the
        # text embeddings, then concatenate them with the cached key/value
        # along the seq_len dimension
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)  # (batch, heads, seq_len, head_dim)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            # Plain self-attention
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Cross-attention: with image inputs, q comes from the query tokens
        # Self-attention: q comes from the query tokens or the text embeddings
        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # * Cache the key and value
        past_key_value = (key_layer, value_layer)

        # Compute the attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask
            attention_scores = attention_scores + attention_mask

        # Softmax normalization yields the attention probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout to prevent overfitting
        attention_probs_dropped = self.dropout(attention_probs)

        # Compute the context representation
        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        # The last element of outputs records the cached key and value
        outputs = outputs + (past_key_value,)
        return outputs
```
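
The `elif past_key_value is not None` branch is the crux of the reuse: during captioning, the freshly computed text key/value are appended after the cached query-token key/value along the sequence dimension. A minimal standalone sketch of that concatenation (the shapes are illustrative):

```python
import torch

batch, heads, head_dim = 2, 12, 64
num_query, num_text = 32, 8

# Self-attention K/V cached from the query tokens during the visual encoding pass
past_key = torch.randn(batch, heads, num_query, head_dim)
past_value = torch.randn(batch, heads, num_query, head_dim)

# K/V freshly computed from the text embeddings during captioning
text_key = torch.randn(batch, heads, num_text, head_dim)
text_value = torch.randn(batch, heads, num_text, head_dim)

# Concatenate along dim=2 (the sequence dimension), exactly as in forward() above
key_layer = torch.cat([past_key, text_key], dim=2)
value_layer = torch.cat([past_value, text_value], dim=2)
print(key_layer.shape)  # torch.Size([2, 12, 40, 64]): text can now attend to the image-infused queries
```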

2. BertLayer: orchestrates the self-attention and cross-attention computation flow.

```python
class BertLayer(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query token padding mask
        head_mask=None,
        encoder_hidden_states=None,   # image tokens
        encoder_attention_mask=None,  # image padding mask
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        # Self-attention computation
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,  # previously cached key and value
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        # Cross-attention computation
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )
        ...
        outputs = (layer_output,) + outputs
        outputs = outputs + (present_key_value,)  # the last element records the cached key and value
        return outputs
```
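
Note that only the first query_length positions go through cross-attention: when query tokens and text tokens are packed into one sequence, the slice `attention_output[:, :query_length, :]` picks out the query part, and the text tokens never attend to the image directly; they see it only through the cached key/value. A toy illustration of that split (sizes assumed):

```python
import torch

query_length, num_text, hidden = 32, 8, 768
attention_output = torch.randn(2, query_length + num_text, hidden)

# Only the query tokens are routed through cross-attention with the image
query_part = attention_output[:, :query_length, :]  # (2, 32, 768)
# The text tokens skip cross-attention entirely
text_part = attention_output[:, query_length:, :]   # (2, 8, 768)
```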

3. BertEncoder: orchestrates the computation across the stacked BertLayers.

```python
class BertEncoder(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query tokens padding mask
        head_mask=None,
        encoder_hidden_states=None,   # image features
        encoder_attention_mask=None,  # image padding mask
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        ...
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            ...
            # If a cache exists, the current BertLayer fetches the key & value
            # previously cached for this layer
            past_key_value = past_key_values[i] if past_key_values is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                query_length,
            )

            hidden_states = layer_outputs[0]
            # The key & value produced by every BertLayer are cached
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
        ...
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
```
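
The cache collected by `next_decoder_cache` is thus one (key, value) pair per layer. A toy illustration of its layout (all sizes assumed):

```python
import torch

num_layers, batch, heads, num_query, head_dim = 12, 2, 12, 32, 64

# One (key, value) pair per BertLayer, as collected by next_decoder_cache above
past_key_values = tuple(
    (
        torch.randn(batch, heads, num_query, head_dim),  # cached key of layer i
        torch.randn(batch, heads, num_query, head_dim),  # cached value of layer i
    )
    for _ in range(num_layers)
)

# On the next forward pass, layer i reads back its own entry
past_key_value = past_key_values[3]  # e.g., the pair cached by layer 3
```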

4. The Image-Grounded Text Generation training objective

```python
...
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    use_cache=True,  # cache the key & value
    return_dict=True,
)
...
##================= Image Captioning ========================##
# Goal: use the Q-Former as a decoder to generate a text caption
# from the image features

# Step 1: prepare the decoder's input token IDs
decoder_input_ids = text_tokens.input_ids.clone()
# Replace the first token with BOS (Begin Of Sentence) to mark "start generating"
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id

# Step 2: build the training targets (labels)
# Replace padding tokens with -100, the label value CrossEntropyLoss ignores by default
labels = decoder_input_ids.masked_fill(
    decoder_input_ids == self.tokenizer.pad_token_id, -100
)

# Step 3: build the attention_mask (covering query tokens and text tokens)
# query_atts is the attention mask for the query tokens, all ones (every position is valid)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
# Concatenate the query-token mask with the text-token mask
attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)

# Step 4: run the Q-Former decoder to generate text
lm_output = self.Qformer(
    decoder_input_ids,                             # input token IDs (e.g., [BOS], a, dog, ...)
    attention_mask=attention_mask,                 # marks which positions are valid (non-padding)
    past_key_values=query_output.past_key_values,  # key/value from the encoding pass, carrying the image information
    return_dict=True,                              # return results as a dict
    labels=labels,                                 # training targets for the loss
)

# Step 5: extract the language-modeling loss
loss_lm = lm_output.loss  # cross-entropy between generated and ground-truth tokens
```
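
A quick toy check of Steps 1 and 2 (the token IDs here are made up):

```python
import torch

pad_id, bos_id = 0, 101  # hypothetical tokenizer ids
decoder_input_ids = torch.tensor([[55, 2023, 3899, 0, 0]])  # two words, two pads

# Step 1: overwrite the first token with BOS
decoder_input_ids[:, 0] = bos_id

# Step 2: mask the pads out of the loss
labels = decoder_input_ids.masked_fill(decoder_input_ids == pad_id, -100)
print(labels)  # tensor([[ 101, 2023, 3899, -100, -100]])
```

Note also that attention_mask in Step 3 has length num_query + num_text, which matches the concatenated key/value built inside BertSelfAttention.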

5. BertLMHeadModel: the autoregressive language-modeling task (e.g., text generation)

```python
class BertLMHeadModel(BertPreTrainedModel):
    ...
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        ...
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        ...
        # self.cls is a classification head (BertOnlyMLMHead) that maps each
        # token's vector into vocabulary space (logits)
        prediction_scores = self.cls(sequence_output)
        ...
        lm_loss = None
        if labels is not None:
            # To predict the next token, shift logits and labels so they align:
            # shifted_prediction_scores: predictions for all tokens except the last
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            # labels: ground-truth tokens starting from the second position
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
        ...
        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
```
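
The shifted next-token loss at the end can be reproduced in isolation; a minimal sketch with toy sizes:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 5, vocab_size)  # logits for 5 positions
labels = torch.randint(0, vocab_size, (2, 5))

# Position t predicts token t+1, so drop the last logit and the first label
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()

loss_fct = CrossEntropyLoss(reduction="mean", label_smoothing=0.1)
lm_loss = loss_fct(
    shifted_prediction_scores.view(-1, vocab_size),
    shifted_labels.view(-1),
)
print(lm_loss)  # a scalar cross-entropy loss
```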