Commit 9fbde04

committed: updates
1 parent 671bd96 commit 9fbde04

File tree: 1 file changed (+16 −215 lines)


src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 16 additions & 215 deletions
````diff
@@ -350,224 +350,25 @@ class BertEmbeddings(nn.Module):
 > ![Multimodal Causal Self-attention Mask](庖丁解牛BLIP2/10.png)
 
 
-
-
-
-
-The core implementation of BertLayer is as follows:
-
 ```python
-class BertLayer(nn.Module):
-    def __init__(self, config, layer_num):
-        super().__init__()
-        ...
-        self.attention = BertAttention(config)
-        # insert a cross-attention layer every cross_attention_freq layers
-        if (
-            self.config.add_cross_attention
-            and layer_num % self.config.cross_attention_freq == 0
-        ):
-            self.crossattention = BertAttention(
-                config, is_cross_attention=self.config.add_cross_attention
-            )
-            self.has_cross_attention = True
-        else:
-            self.has_cross_attention = False
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-        self.intermediate_query = BertIntermediate(config)
-        self.output_query = BertOutput(config)
-
-    def forward(
-        self,
-        hidden_states,  # query tokens
-        attention_mask=None,  # query token padding mask
-        head_mask=None,
-        encoder_hidden_states=None,  # image tokens
-        encoder_attention_mask=None,  # image padding mask
-        past_key_value=None,
-        output_attentions=False,
-        query_length=0,
-    ):
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = (
-            past_key_value[:2] if past_key_value is not None else None
-        )
-        # self-attention over the query tokens (and text tokens, if present)
-        self_attention_outputs = self.attention(
-            hidden_states,
-            attention_mask,
-            head_mask,
-            output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+        ##================= Image Captioning ========================##
+        decoder_input_ids = text_tokens.input_ids.clone()
+        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+        labels = decoder_input_ids.masked_fill(
+            decoder_input_ids == self.tokenizer.pad_token_id, -100
         )
-        attention_output = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:-1]
-
-        present_key_value = self_attention_outputs[-1]
-
-        if query_length > 0:  # number of query tokens = 32
-            query_attention_output = attention_output[:, :query_length, :]
-            # the query tokens become "context states" after self-attention
-            if self.has_cross_attention:
-                # q (context states); k and v come from the image encoder
-                # q(b,32,h) * k(b,seq_len,h).t = (b,32,seq_len)
-                # score(b,32,seq_len) * v(b,seq_len,h) = (b,32,h)
-                cross_attention_outputs = self.crossattention(
-                    query_attention_output,
-                    attention_mask,
-                    head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions=output_attentions,
-                )
-                query_attention_output = cross_attention_outputs[0]
-                outputs = (
-                    outputs + cross_attention_outputs[1:-1]
-                )  # add cross attentions if we output attention weights
-            # chunked feed-forward over the sequence dimension
-            layer_output = apply_chunking_to_forward(
-                self.feed_forward_chunk_query,
-                self.chunk_size_feed_forward,
-                self.seq_len_dim,
-                query_attention_output,
-            )
-            if attention_output.shape[1] > query_length:
-                layer_output_text = apply_chunking_to_forward(
-                    self.feed_forward_chunk,
-                    self.chunk_size_feed_forward,
-                    self.seq_len_dim,
-                    attention_output[:, query_length:, :],
-                )
-                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
-        else:
-            layer_output = apply_chunking_to_forward(
-                self.feed_forward_chunk,
-                self.chunk_size_feed_forward,
-                self.seq_len_dim,
-                attention_output,
-            )
-        outputs = (layer_output,) + outputs
-
-        outputs = outputs + (present_key_value,)
-
-        return outputs
-
-    def feed_forward_chunk(self, attention_output):
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
-
-    def feed_forward_chunk_query(self, attention_output):
-        intermediate_output = self.intermediate_query(attention_output)
-        layer_output = self.output_query(intermediate_output, attention_output)
-        return layer_output
-```
-
-```python
-class BertAttention(nn.Module):
-    def __init__(self, config, is_cross_attention=False):
-        super().__init__()
-        self.self = BertSelfAttention(config, is_cross_attention)
-        self.output = BertSelfOutput(config)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
-        self_outputs = self.self(
-            hidden_states,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
-        )  # (context_layer, attention_probs)
-        attention_output = self.output(self_outputs[0], hidden_states)
-
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
-```
 
-```python
-class BertSelfAttention(nn.Module):
-    ...
-    def transpose_for_scores(self, x):
-        # reshape the tensor for multi-head attention
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+            image.device
         )
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
-
-        # determine whether this is cross-attention
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention:  # key and value come from the image
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        mixed_query_layer = self.query(hidden_states)
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        past_key_value = (key_layer, value_layer)
-
-        # compute attention scores
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            # apply the attention mask
-            attention_scores = attention_scores + attention_mask
-
-        # softmax-normalize to get attention probabilities
-        attention_probs = nn.Softmax(dim=-1)(attention_scores)
-
-        # dropout to prevent overfitting
-        attention_probs_dropped = self.dropout(attention_probs)
-
-        # compute the context representation
-        context_layer = torch.matmul(attention_probs_dropped, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        outputs = (
-            (context_layer, attention_probs) if output_attentions else (context_layer,)
+        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
+        lm_output = self.Qformer(
+            decoder_input_ids,
+            attention_mask=attention_mask,
+            past_key_values=query_output.past_key_values,
+            return_dict=True,
+            labels=labels,
         )
 
-        outputs = outputs + (past_key_value,)
-        return outputs
-```
+        loss_lm = lm_output.loss
+```
````
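The 16 added lines are the image-captioning (language-modeling) branch of the Q-Former pre-training loss: pad positions in the decoder labels are replaced with `-100` so the loss ignores them, and the learned query tokens contribute an all-ones prefix to the attention mask because they reach the text decoder through `past_key_values=query_output.past_key_values` rather than as input ids. Below is a minimal, self-contained sketch of just this label/mask construction, using made-up ids and toy tensors (`bos_id`, `pad_id`, `text_ids`, etc.) in place of the repository's tokenizer and model:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for the real tokenizer ids and batch (hypothetical values).
bos_id, pad_id, vocab_size = 0, 1, 30
num_query_tokens = 4  # 32 in BLIP-2; kept small here

text_ids = torch.tensor([[5, 7, 9, pad_id],
                         [6, 8, pad_id, pad_id]])   # (batch=2, seq=4)
text_pad_mask = (text_ids != pad_id).long()         # like text_tokens.attention_mask

# 1) Decoder inputs start with BOS, as in the added code.
decoder_input_ids = text_ids.clone()
decoder_input_ids[:, 0] = bos_id

# 2) Labels: pad positions become -100, the default ignore_index of cross_entropy.
labels = decoder_input_ids.masked_fill(decoder_input_ids == pad_id, -100)

# 3) Query tokens are a fixed-length, never-padded prefix -> all-ones mask slice.
query_atts = torch.ones(text_ids.size(0), num_query_tokens, dtype=torch.long)
attention_mask = torch.cat([query_atts, text_pad_mask], dim=1)  # (2, 4 + 4)

# 4) Toy logits show that the -100 positions contribute nothing to the loss.
#    (The real Q-Former shifts logits/labels internally; this only illustrates ignore_index.)
logits = torch.randn(text_ids.size(0), text_ids.size(1), vocab_size)
loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), ignore_index=-100)
print(attention_mask.shape, loss.item())
```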

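The shape comments in the removed `BertLayer` code (`q(b,32,h) * k(b,seq_len,h).t = (b,32,seq_len)`, then `score(b,32,seq_len) * v(b,seq_len,h) = (b,32,h)`) describe the query-token → image-token cross-attention and are easy to verify in isolation. A quick sketch with made-up dimensions, ignoring the multi-head split that `transpose_for_scores` performs:

```python
import math
import torch

# Made-up sizes: batch, learned query tokens, image tokens (ViT patches + CLS), hidden dim.
b, num_query, seq_len, h = 2, 32, 257, 64

q = torch.randn(b, num_query, h)   # query tokens after self-attention ("context states")
k = torch.randn(b, seq_len, h)     # keys projected from the frozen image encoder's outputs
v = torch.randn(b, seq_len, h)     # values projected from the same image features

# (b, 32, h) @ (b, h, seq_len) -> (b, 32, seq_len): one score per (query token, image token) pair
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(h)
probs = scores.softmax(dim=-1)

# (b, 32, seq_len) @ (b, seq_len, h) -> (b, 32, h): each query token becomes a mix of image features
context = torch.matmul(probs, v)

print(scores.shape)   # torch.Size([2, 32, 257])
print(context.shape)  # torch.Size([2, 32, 64])
```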