
Commit fcc0be8

committed
updates
1 parent 9fbde04 commit fcc0be8

File tree

1 file changed

+200
-1
lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 200 additions & 1 deletion
@@ -349,8 +349,207 @@ class BertEmbeddings(nn.Module):
>
> ![Multimodal Causal Self-attention Mask](庖丁解牛BLIP2/10.png)

**Visual encoding stage**:

The image is first encoded by a vision encoder (e.g. ViT) into image features `image_embeds`. The query tokens absorb these image features through cross-attention and then produce a compressed visual representation through self-attention. What gets cached in `past_key_values` is the key/value of the query tokens' self-attention (not the key/value of the cross-attention).
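As a rough shape sketch of this stage (the sizes below follow the common BLIP-2 configuration with a ViT-g/14 encoder and 32 query tokens; they are illustrative assumptions, not values taken from this file):

```python
import torch

# Assumed sizes (illustrative): ViT-g/14 features, 32 query tokens,
# Q-Former hidden size 768, 12 heads, 12 layers.
batch, num_query, hidden, heads, layers = 2, 32, 768, 12, 12
head_dim = hidden // heads

image_embeds = torch.randn(batch, 257, 1408)          # vision encoder output: [CLS] + 16x16 patch features
query_tokens = torch.randn(batch, num_query, hidden)  # learned query tokens fed into the Q-Former

# What gets cached: one (key, value) pair per layer, produced by the
# self-attention over the 32 query tokens (not by the cross-attention
# over image_embeds).
past_key_values = tuple(
    (
        torch.randn(batch, heads, num_query, head_dim),  # key:   (B, H, 32, 64)
        torch.randn(batch, heads, num_query, head_dim),  # value: (B, H, 32, 64)
    )
    for _ in range(layers)
)
print(len(past_key_values), past_key_values[0][0].shape)  # 12 torch.Size([2, 12, 32, 64])
```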
The Q-Former uses `past_key_values` to cache and reuse the self-attention key/value of each encoder layer:

1. `BertSelfAttention`: unifies the self-attention and cross-attention flow; each forward pass returns the key & value that may need to be cached (a shape sketch of how cached and freshly computed key/value are concatenated is given at the end of this section).
```python
class BertSelfAttention(nn.Module):
    ...
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # Determine whether this is a cross-attention call
        is_cross_attention = encoder_hidden_states is not None

        # Cross-attention: key and value both come from the image features,
        # while the query comes from the query tokens
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        # If cached key/value are passed in, first compute key/value from the text
        # embeddings, then concatenate them with the cached key/value along the
        # seq_len dimension
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)  # (Batch, Heads, Seq_len, Head_size)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            # Self-attention
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Cross-attention: image features are passed in, so q comes from the query tokens
        # Self-attention: q comes from the query tokens or the text embeddings
        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # * cache key and value
        past_key_value = (key_layer, value_layer)

        # Compute attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask
            attention_scores = attention_scores + attention_mask

        # Softmax normalization gives the attention probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # Dropout to prevent overfitting
        attention_probs_dropped = self.dropout(attention_probs)

        # Compute the context representation
        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        # The last element of outputs holds the cached key and value
        outputs = outputs + (past_key_value,)
        return outputs
```
2. `BertLayer`: organizes the self-attention and cross-attention computation flow within a single layer.
```python
class BertLayer(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query token padding mask
        head_mask=None,
        encoder_hidden_states=None,   # image tokens
        encoder_attention_mask=None,  # image padding mask
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        # Self-attention
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,  # cached key and value
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        # Cross-attention
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )
        ...
        outputs = (layer_output,) + outputs
        outputs = outputs + (present_key_value,)  # the last element of outputs holds the cached key and value
        return outputs
```
3. `BertEncoder`: organizes the computation across the stack of `BertLayer`s.

```python
class BertEncoder(nn.Module):
    ...
    def forward(
        self,
        hidden_states,                # query tokens
        attention_mask=None,          # query tokens padding mask
        head_mask=None,
        encoder_hidden_states=None,   # images
        encoder_attention_mask=None,  # images padding mask
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        ...
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            ...
            # If a cache is provided, the current BertLayer retrieves the key & value
            # previously cached for this layer
            past_key_value = past_key_values[i] if past_key_values is not None else None
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                query_length,
            )

            hidden_states = layer_outputs[0]
            # The key & value produced by every BertLayer are added to the cache
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
        ...
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
```

4. The Image-Grounded Text Generation training objective

```python
...
query_output = self.Qformer.bert(
    query_embeds=query_tokens,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    use_cache=True,  # cache key & value
    return_dict=True,
)
...
##================= Image Captioning ========================##
decoder_input_ids = text_tokens.input_ids.clone()
decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
```
@@ -365,7 +564,7 @@ class BertEmbeddings(nn.Module):

lm_output = self.Qformer(
    decoder_input_ids,
    attention_mask=attention_mask,
    past_key_values=query_output.past_key_values,  # pass in the cached key & value of every layer
    return_dict=True,
    labels=labels,
)
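To make the cache reuse concrete, here is a minimal, self-contained sketch of the `dim=2` concatenation performed in `BertSelfAttention` when the `past_key_values` from the visual pass are fed back in during Image-Grounded Text Generation. The sizes (32 cached query tokens, 16 text tokens, 12 heads, head size 64) are assumptions for illustration, not code from the repository:

```python
import torch

batch, heads, head_dim = 2, 12, 64
num_query_tokens, text_len = 32, 16  # assumed sizes

# Cached during the visual pass: self-attention key/value of the query tokens
cached_key = torch.randn(batch, heads, num_query_tokens, head_dim)
cached_value = torch.randn(batch, heads, num_query_tokens, head_dim)

# Computed during the text pass from the text embeddings
text_key = torch.randn(batch, heads, text_len, head_dim)
text_value = torch.randn(batch, heads, text_len, head_dim)

# Concatenate along the sequence dimension (dim=2), as BertSelfAttention does:
# every text token can now attend to the 32 compressed visual tokens as well as
# to the text positions allowed by the causal mask.
key_layer = torch.cat([cached_key, text_key], dim=2)
value_layer = torch.cat([cached_value, text_value], dim=2)
print(key_layer.shape)  # torch.Size([2, 12, 48, 64])
```

This is also why only the self-attention key/value of the query tokens needs to be cached: the text tokens attend to the compressed visual representation through these cached tensors, so the image features themselves never have to be re-attended during text generation.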
