@@ -350,224 +350,25 @@ class BertEmbeddings(nn.Module):
 > ![Multimodal Causal Self-attention Mask](庖丁解牛BLIP2/10.png)
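The figure above shows the multimodal causal self-attention mask used for image captioning. For intuition, here is a minimal sketch of my own (the helper `multimodal_causal_mask` is hypothetical, not code from this repo): query tokens attend bidirectionally among themselves, while text tokens attend to all query tokens and, causally, only to earlier text tokens.

```python
import torch

def multimodal_causal_mask(num_query: int, num_text: int) -> torch.Tensor:
    # True = attention allowed; rows attend, columns are attended to
    total = num_query + num_text
    mask = torch.zeros(total, total, dtype=torch.bool)
    # query rows: bidirectional over the query block, blind to text
    mask[:num_query, :num_query] = True
    # text rows: see every query token, plus text tokens up to themselves
    mask[num_query:, :num_query] = True
    mask[num_query:, num_query:] = torch.tril(
        torch.ones(num_text, num_text, dtype=torch.bool)
    )
    return mask

print(multimodal_causal_mask(2, 3).int())
```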
 
 
-
-
-
-
-The core implementation of BertLayer is as follows:
-
 ```python
-class BertLayer(nn.Module):
-    def __init__(self, config, layer_num):
-        super().__init__()
-        ...
-        self.attention = BertAttention(config)
-        # insert a cross-attention block every cross_attention_freq layers
-        if (
-            self.config.add_cross_attention
-            and layer_num % self.config.cross_attention_freq == 0
-        ):
-            self.crossattention = BertAttention(
-                config, is_cross_attention=self.config.add_cross_attention
-            )
-            self.has_cross_attention = True
-        else:
-            self.has_cross_attention = False
-        self.intermediate = BertIntermediate(config)
-        self.output = BertOutput(config)
-
-        self.intermediate_query = BertIntermediate(config)
-        self.output_query = BertOutput(config)
-
-    def forward(
-        self,
-        hidden_states,  # query tokens
-        attention_mask=None,  # query token padding mask
-        head_mask=None,
-        encoder_hidden_states=None,  # image tokens
-        encoder_attention_mask=None,  # image padding mask
-        past_key_value=None,
-        output_attentions=False,
-        query_length=0,
-    ):
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = (
-            past_key_value[:2] if past_key_value is not None else None
-        )
-        # self-attention over the query tokens (and any text tokens appended after them)
-        self_attention_outputs = self.attention(
-            hidden_states,
-            attention_mask,
-            head_mask,
-            output_attentions=output_attentions,
-            past_key_value=self_attn_past_key_value,
+##================= Image Captioning ========================##
+decoder_input_ids = text_tokens.input_ids.clone()
+decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+labels = decoder_input_ids.masked_fill(
+    decoder_input_ids == self.tokenizer.pad_token_id, -100
 )
-        attention_output = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:-1]
-
-        present_key_value = self_attention_outputs[-1]
-
-        if query_length > 0:  # number of query tokens = 32
-            query_attention_output = attention_output[:, :query_length, :]
-            # after self-attention the query tokens serve as context states
-            if self.has_cross_attention:
-                # q comes from the context states; k and v come from the image encoder
-                # q(b,32,h) @ k(b,seq_len,h).T = (b,32,seq_len)
-                # score(b,32,seq_len) @ v(b,seq_len,h) = (b,32,h)
-                cross_attention_outputs = self.crossattention(
-                    query_attention_output,
-                    attention_mask,
-                    head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    output_attentions=output_attentions,
-                )
-                query_attention_output = cross_attention_outputs[0]
-                outputs = (
-                    outputs + cross_attention_outputs[1:-1]
-                )  # add cross attentions if we output attention weights
-            # process the feed-forward in parallel chunks
-            layer_output = apply_chunking_to_forward(
-                self.feed_forward_chunk_query,
-                self.chunk_size_feed_forward,
-                self.seq_len_dim,
-                query_attention_output,
-            )
-            if attention_output.shape[1] > query_length:
-                layer_output_text = apply_chunking_to_forward(
-                    self.feed_forward_chunk,
-                    self.chunk_size_feed_forward,
-                    self.seq_len_dim,
-                    attention_output[:, query_length:, :],
-                )
-                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
-        else:
-            layer_output = apply_chunking_to_forward(
-                self.feed_forward_chunk,
-                self.chunk_size_feed_forward,
-                self.seq_len_dim,
-                attention_output,
-            )
-        outputs = (layer_output,) + outputs
-
-        outputs = outputs + (present_key_value,)
-
-        return outputs
-
-    def feed_forward_chunk(self, attention_output):
-        intermediate_output = self.intermediate(attention_output)
-        layer_output = self.output(intermediate_output, attention_output)
-        return layer_output
-
-    def feed_forward_chunk_query(self, attention_output):
-        intermediate_output = self.intermediate_query(attention_output)
-        layer_output = self.output_query(intermediate_output, attention_output)
-        return layer_output
-```
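Note: `apply_chunking_to_forward` is the Hugging Face transformers helper behind the chunked feed-forward above. Because the FFN is applied position-wise, running it chunk by chunk along the sequence dimension and concatenating gives the same result as one full pass, only with lower peak memory. A self-contained sanity check (a plain `nn.Linear` stands in for the real FFN; the import path assumes a recent transformers version):

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward

ffn = torch.nn.Linear(16, 16)  # stand-in for BertIntermediate/BertOutput

def feed_forward_chunk(x):
    return ffn(x)

x = torch.randn(2, 32, 16)  # (batch, 32 query tokens, hidden)
# chunk_size=8 along seq_len_dim=1: four chunks, concatenated back together
chunked = apply_chunking_to_forward(feed_forward_chunk, 8, 1, x)
print(torch.allclose(chunked, feed_forward_chunk(x), atol=1e-6))  # True
```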
-
-```python
-class BertAttention(nn.Module):
-    def __init__(self, config, is_cross_attention=False):
-        super().__init__()
-        self.self = BertSelfAttention(config, is_cross_attention)
-        self.output = BertSelfOutput(config)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
-        self_outputs = self.self(
-            hidden_states,
-            attention_mask,
-            head_mask,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            past_key_value,
-            output_attentions,
-        )  # (context_layer, attention_probs)
-        attention_output = self.output(self_outputs[0], hidden_states)
-
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
-        return outputs
-```
 
-```python
-class BertSelfAttention(nn.Module):
-    ...
-    def transpose_for_scores(self, x):
-        # reshape the tensor for multi-head attention
-        new_x_shape = x.size()[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
+query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+    image.device
 )
-        x = x.view(*new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
-
-        # decide whether this is cross-attention
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention:  # key and value come from the image
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        mixed_query_layer = self.query(hidden_states)
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        past_key_value = (key_layer, value_layer)
-
-        # compute the attention scores
-        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            # apply the attention mask
-            attention_scores = attention_scores + attention_mask
-
-        # softmax normalization yields the attention probabilities
-        attention_probs = nn.Softmax(dim=-1)(attention_scores)
-
-        # dropout to prevent overfitting
-        attention_probs_dropped = self.dropout(attention_probs)
-
-        # compute the context representation
-        context_layer = torch.matmul(attention_probs_dropped, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        outputs = (
-            (context_layer, attention_probs) if output_attentions else (context_layer,)
+attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
+lm_output = self.Qformer(
+    decoder_input_ids,
+    attention_mask=attention_mask,
+    past_key_values=query_output.past_key_values,
+    return_dict=True,
+    labels=labels,
 )
 
-        outputs = outputs + (past_key_value,)
-        return outputs
-```
+loss_lm = lm_output.loss
+```
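To make the label construction above concrete, here is a minimal sketch with toy token IDs (the BOS/PAD values 101/0 are assumptions, not this repo's tokenizer): position 0 is overwritten with BOS, and padding positions become -100, the index PyTorch's cross-entropy loss ignores, so only real caption tokens contribute to `loss_lm`. The all-ones `query_atts` reflects that the 32 query tokens are never padded.

```python
import torch

bos_token_id, pad_token_id = 101, 0  # hypothetical IDs
input_ids = torch.tensor([[7, 15, 23, 9, pad_token_id, pad_token_id]])

decoder_input_ids = input_ids.clone()
decoder_input_ids[:, 0] = bos_token_id  # start-of-sequence marker
labels = decoder_input_ids.masked_fill(
    decoder_input_ids == pad_token_id, -100  # ignored by CrossEntropyLoss
)
print(decoder_input_ids)  # tensor([[101, 15, 23, 9, 0, 0]])
print(labels)             # tensor([[101, 15, 23, 9, -100, -100]])
```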