Skip to content

Commit 18f4871

Browse files
committed
updates
1 parent f24cfe7 commit 18f4871

File tree

5 files changed

+100
-0
lines changed

5 files changed

+100
-0
lines changed

src/LLM/图解BERT.md

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,103 @@ class BertAttention(nn.Module):
510510
attention_output = self.output(self_outputs, input_tensor)
511511
return attention_output
512512
```
513+
514+
## 预训练
515+
516+
### BertPredictionHeadTransform
517+
518+
![BertPredictionHeadTransform结构图](图解BERT/17.png)
519+
520+
```python
521+
class BertPredictionHeadTransform(nn.Module):
522+
def __init__(self, config):
523+
super(BertPredictionHeadTransform, self).__init__()
524+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
525+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
526+
self.transform_act_fn = ACT2FN[config.hidden_act]
527+
else:
528+
self.transform_act_fn = config.hidden_act
529+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
530+
531+
def forward(self, hidden_states):
532+
hidden_states = self.dense(hidden_states)
533+
hidden_states = self.transform_act_fn(hidden_states)
534+
hidden_states = self.LayerNorm(hidden_states)
535+
return hidden_states
536+
```
537+
538+
### BertLMPredictionHead
539+
540+
![BertLMPredictionHead结构图](图解BERT/18.png)
541+
542+
```python
543+
class BertLMPredictionHead(nn.Module):
544+
def __init__(self, config):
545+
super(BertLMPredictionHead, self).__init__()
546+
self.transform = BertPredictionHeadTransform(config)
547+
548+
# The output weights are the same as the input embeddings, but there is
549+
# an output-only bias for each token.
550+
self.decoder = nn.Linear(config.hidden_size,
551+
config.vocab_size,
552+
bias=False)
553+
554+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
555+
556+
def forward(self, hidden_states):
557+
hidden_states = self.transform(hidden_states)
558+
hidden_states = self.decoder(hidden_states) + self.bias
559+
return hidden_states
560+
```
561+
562+
### BertPreTrainingHeads
563+
564+
![BertPreTrainingHeads结构图](图解BERT/19.png)
565+
566+
```python
567+
class BertPreTrainingHeads(nn.Module):
568+
def __init__(self, config):
569+
super(BertPreTrainingHeads, self).__init__()
570+
self.predictions = BertLMPredictionHead(config)
571+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
572+
573+
def forward(self, sequence_output, pooled_output):
574+
prediction_scores = self.predictions(sequence_output) #
575+
seq_relationship_score = self.seq_relationship(pooled_output) # 两个句子是否为上下句关系
576+
return prediction_scores, seq_relationship_score
577+
```
578+
579+
### BertForPreTraining
580+
581+
![BertForPreTraining结构图](图解BERT/20.png)
582+
583+
```python
584+
class BertForPreTraining(BertPreTrainedModel):
585+
def __init__(self, config):
586+
super(BertForPreTraining, self).__init__(config)
587+
self.bert = BertModel(config)
588+
self.cls = BertPreTrainingHeads(config)
589+
590+
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
591+
masked_lm_labels=None, next_sentence_label=None):
592+
593+
outputs = self.bert(input_ids,
594+
attention_mask=attention_mask,
595+
token_type_ids=token_type_ids,
596+
position_ids=position_ids,
597+
head_mask=head_mask)
598+
599+
sequence_output, pooled_output = outputs[:2] # 隐藏层输出,CLS Token Embeddings
600+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
601+
602+
outputs = (prediction_scores, seq_relationship_score,)
603+
# 计算掩码语言损失 和 下一个句子预测损失
604+
if masked_lm_labels is not None and next_sentence_label is not None:
605+
loss_fct = CrossEntropyLoss(ignore_index=-1)
606+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
607+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
608+
total_loss = masked_lm_loss + next_sentence_loss
609+
outputs = (total_loss,) + outputs
610+
611+
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
612+
```

src/LLM/图解BERT/17.png

102 KB
Loading

src/LLM/图解BERT/18.png

202 KB
Loading

src/LLM/图解BERT/19.png

286 KB
Loading

src/LLM/图解BERT/20.png

1.86 MB
Loading

0 commit comments

Comments
 (0)