        attention_output = self.output(self_outputs, input_tensor)
        return attention_output
```

## Pre-training

### BertPredictionHeadTransform

![BertPredictionHeadTransform structure diagram](图解BERT/17.png)

```python
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # hidden_act can be given as a string (looked up in ACT2FN) or directly as a callable
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        # dense -> activation -> LayerNorm; the hidden size stays unchanged
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
```
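
To see what the transform does to a batch, here is a minimal sketch (not from the original source) that pushes dummy hidden states through it. The `SimpleNamespace` config is a hypothetical stand-in for `BertConfig`, its values are the usual bert-base settings, and `ACT2FN`/`BertLayerNorm` from the surrounding module are assumed to be in scope.

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for BertConfig; the values are the standard
# bert-base settings, assumed here purely for illustration.
config = SimpleNamespace(hidden_size=768, hidden_act="gelu", layer_norm_eps=1e-12)

transform = BertPredictionHeadTransform(config)
hidden_states = torch.randn(2, 128, 768)  # (batch, seq_len, hidden_size)
out = transform(hidden_states)
print(out.shape)  # torch.Size([2, 128, 768]): the hidden size is preserved
```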

### BertLMPredictionHead

![BertLMPredictionHead structure diagram](图解BERT/18.png)

```python
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        # project to vocabulary logits and add the per-token output bias
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states
```
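
The comment above notes that the output weights are shared with the input embeddings, but the sharing is not done in this class; in the pytorch-transformers source it happens in the model's weight-tying step after construction. A minimal sketch of the idea, with bert-base sizes assumed:

```python
import torch.nn as nn

hidden_size, vocab_size = 768, 30522  # bert-base values, assumed

word_embeddings = nn.Embedding(vocab_size, hidden_size)   # input embedding table
decoder = nn.Linear(hidden_size, vocab_size, bias=False)  # output projection

# nn.Linear stores its weight as (out_features, in_features), i.e.
# (vocab_size, hidden_size), which is exactly the embedding table's shape,
# so the two modules can share a single parameter.
decoder.weight = word_embeddings.weight
assert decoder.weight is word_embeddings.weight
```

Sharing removes one vocabulary-sized weight matrix and couples each token's input and output representation; only the separate `bias` above remains output-only.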

### BertPreTrainingHeads

![BertPreTrainingHeads structure diagram](图解BERT/19.png)

```python
class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)          # per-token vocabulary logits for masked LM
        seq_relationship_score = self.seq_relationship(pooled_output)  # whether the two sentences are consecutive
        return prediction_scores, seq_relationship_score
```
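
A minimal usage sketch (again with a hypothetical config stand-in and assumed bert-base sizes; the classes above and their dependencies must be in scope) showing what each head returns:

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for BertConfig with assumed bert-base values.
config = SimpleNamespace(hidden_size=768, vocab_size=30522,
                         hidden_act="gelu", layer_norm_eps=1e-12)
heads = BertPreTrainingHeads(config)

sequence_output = torch.randn(2, 128, config.hidden_size)  # per-token hidden states
pooled_output = torch.randn(2, config.hidden_size)         # pooled [CLS] vector

prediction_scores, seq_relationship_score = heads(sequence_output, pooled_output)
print(prediction_scores.shape)       # torch.Size([2, 128, 30522]): MLM logits
print(seq_relationship_score.shape)  # torch.Size([2, 2]): is-next / not-next logits
```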
578
+
579
+ ### BertForPreTraining
580
+
581
+ ![ BertForPreTraining结构图] ( 图解BERT/20.png )
582
+
583
```python
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output, pooled_output = outputs[:2]  # last-layer hidden states, pooled [CLS] embedding
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,)
        # compute the masked LM loss and the next sentence prediction loss
        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
```
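
One detail worth spelling out is `ignore_index=-1`: positions that were not masked carry the label -1 and contribute nothing to the masked LM loss. A toy, self-contained sketch (the vocabulary size and labels are made up for illustration):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 6                                      # toy vocabulary, assumed
prediction_scores = torch.randn(1, 4, vocab_size)   # (batch, seq_len, vocab)
masked_lm_labels = torch.tensor([[-1, -1, 3, -1]])  # only position 2 was masked

loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size),
                          masked_lm_labels.view(-1))
print(masked_lm_loss)  # cross-entropy over the single masked position only
```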