Skip to content

Commit a984349

Browse files
committed
updates
1 parent adeebae commit a984349

File tree

3 files changed

+363
-1
lines changed

3 files changed

+363
-1
lines changed

src/LLM/图解BERT.md

Lines changed: 363 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,8 @@ class BertAttention(nn.Module):
513513

514514
## 预训练
515515

516+
![预训练与微调](图解BERT/22.png)
517+
516518
### BertPredictionHeadTransform
517519

518520
![BertPredictionHeadTransform结构图](图解BERT/17.png)
@@ -609,4 +611,364 @@ class BertForPreTraining(BertPreTrainedModel):
609611
outputs = (total_loss,) + outputs
610612

611613
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
612-
```
614+
```
615+
616+
## 其他下游任务
617+
618+
![Bert支持的下游任务图](图解BERT/21.png)
619+
620+
### 问答任务
621+
622+
在 BERT 的问答任务中,典型的输入是一个包含 **问题(Question)****上下文(Context)** 的文本对。例如:
623+
624+
> **问题**: “谁写了《哈姆雷特》?”
625+
> **上下文**: “莎士比亚是英国文学史上最伟大的作家之一,他写了包括《哈姆雷特》、《麦克白》等著名悲剧。”
626+
627+
1. 输入格式(Tokenization 后的形式),在使用 `BertTokenizer` 编码后,输入会变成如下结构:
628+
629+
```json
630+
[CLS] 问题 tokens [SEP] 上下文 tokens [SEP]
631+
```
632+
2. BERT 的输出(Outputs),通过调用 `self.bert(...)`,你将得到一个包含多个元素的 tuple 输出:
633+
634+
```python
635+
outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
636+
```
637+
638+
返回值形如:
639+
640+
```python
641+
(
642+
sequence_output, # (batch_size, seq_length, hidden_size)
643+
pooled_output, # (batch_size, hidden_size)
644+
)
645+
```
646+
主要输出项解释:
647+
648+
`sequence_output`: 最终每个 token 的表示
649+
650+
- 形状:`(batch_size, seq_length, hidden_size)`
651+
- 是模型最后一层所有 token(包括问题和上下文)的隐藏状态。
652+
- 在问答任务中,我们主要使用它来预测答案的起始和结束位置。
653+
654+
`pooled_output`: 句子级别表示(不常用)
655+
656+
- 形状:`(batch_size, hidden_size)`
657+
-`[CLS]` token 经过一层全连接后的输出。
658+
- 在分类任务中更有用,在问答任务中一般不会使用这个输出。
659+
660+
3. 如何利用 BERT 输出做问答预测?
661+
662+
`BertForQuestionAnswering` 中,使用了如下逻辑:
663+
664+
```python
665+
logits = self.qa_outputs(sequence_output) # (batch_size, seq_length, 2)
666+
start_logits, end_logits = logits.split(1, dim=-1) # split into start and end
667+
start_logits = start_logits.squeeze(-1) # (batch_size, seq_length)
668+
end_logits = end_logits.squeeze(-1)
669+
```
670+
`qa_outputs` 层的作用:
671+
- 是一个线性层:`nn.Linear(config.hidden_size, 2)`
672+
- 将每个 token 的 `hidden_size` 向量映射成两个分数:一个是该 token 作为答案开始的可能性,另一个是作为答案结束的可能性。
673+
674+
输出解释:
675+
- `start_logits`: 每个 token 是答案起点的得分(未归一化)。
676+
- `end_logits`: 每个 token 是答案终点的得分。
677+
678+
比如对于一个长度为 128 的序列,每个 token 都有一个对应的 start/end 分数:
679+
680+
```python
681+
start_scores = torch.softmax(start_logits, dim=-1) # softmax 得到概率
682+
end_scores = torch.softmax(end_logits, dim=-1)
683+
684+
# 找出最可能是 start 和 end 的位置
685+
start_index = torch.argmax(start_scores)
686+
end_index = torch.argmax(end_scores)
687+
```
688+
689+
如果 `start_index <= end_index`,那么可以组合这两个索引得到答案 span。
690+
691+
692+
#### 代码实现
693+
694+
```python
695+
class BertForQuestionAnswering(BertPreTrainedModel):
696+
def __init__(self, config):
697+
super(BertForQuestionAnswering, self).__init__(config)
698+
self.num_labels = config.num_labels # 通常是 2,即 start 和 end
699+
self.bert = BertModel(config)
700+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
701+
702+
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
703+
start_positions=None, end_positions=None):
704+
705+
outputs = self.bert(input_ids,
706+
attention_mask=attention_mask,
707+
token_type_ids=token_type_ids,
708+
position_ids=position_ids)
709+
710+
sequence_output = outputs[0]
711+
# (batch,seq_len,hidden_size) ---> (batch,seq_len,2)
712+
logits = self.qa_outputs(sequence_output)
713+
714+
start_logits, end_logits = logits.split(1, dim=-1)
715+
start_logits = start_logits.squeeze(-1) # (batch,seq_len)
716+
end_logits = end_logits.squeeze(-1)
717+
718+
outputs = (start_logits, end_logits,)
719+
# 计算交叉熵损失
720+
if start_positions is not None and end_positions is not None:
721+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
722+
# ignored_index = seq_len
723+
ignored_index = start_logits.size(1)
724+
# clamp_ 是 PyTorch 中的一个方法,用于将张量中的值限制在指定的范围内。
725+
# 它的语法是 tensor.clamp_(min, max) ,表示将张量中的值限制在 min 和 max 之间。
726+
# 如果值小于 min ,则将其设置为 min ;如果值大于 max ,则将其设置为 max 。
727+
start_positions.clamp_(0, ignored_index)
728+
end_positions.clamp_(0, ignored_index)
729+
730+
# ignore_index: 用于指定在计算损失时忽略的标签索引。
731+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
732+
# 分别计算答案起始下标和结束下标预测得到的交叉熵损失
733+
start_loss = loss_fct(start_logits, start_positions)
734+
end_loss = loss_fct(end_logits, end_positions)
735+
total_loss = (start_loss + end_loss) / 2
736+
outputs = (total_loss,) + outputs
737+
738+
return outputs # (loss), start_logits, end_logits
739+
740+
```
741+
742+
#### 易混淆
743+
744+
BERT 是一个 **基于上下文编码(Contextual Encoder)** 的模型,不是自回归生成器。它不会“生成”新的文本,而是对输入文本中每个 token 的角色进行分类(如判断哪个是答案的开始、结束)。所以最终的答案只能来自原始输入文本中的某一段子串。
745+
746+
📚 详细解释
747+
748+
1. ✅ BERT 是一个 Encoder-only 模型
749+
750+
- BERT 只包含 Transformer 的 encoder 部分。
751+
752+
- 它的作用是给定一个完整的句子(或两个句子),对每个 token 生成一个上下文相关的表示(contextualized representation)。
753+
754+
-**不具有生成能力**,不能像 GPT 这样的 decoder-only 模型那样逐词生成新内容。
755+
756+
---
757+
758+
2. 🔍 QA 任务的本质:定位答案 span 而非生成答案
759+
760+
在 SQuAD 这类抽取式问答任务中:
761+
762+
- 答案必须是原文中的连续片段(span)。
763+
764+
- 所以模型的任务是:
765+
766+
- 给出问题和上下文;
767+
768+
- 在上下文中找到最可能的答案起始位置和结束位置;
769+
770+
- 最终答案就是上下文中这两个位置之间的字符串。
771+
772+
BERT 做的就是这个定位任务,而不是重新生成一个新的答案。
773+
774+
---
775+
776+
3. 🧩 输入与输出的关系
777+
778+
```python
779+
answer_tokens = input_ids[0][start_index : end_index + 1]
780+
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
781+
```
782+
783+
这段代码的意思是:
784+
785+
- `start_index``end_index` 是模型预测出的答案的起始和结束位置。
786+
787+
- 我们从原始输入的 `input_ids` 中取出对应的 token ID 子序列。
788+
789+
- 使用 tokenizer 把这些 token ID 解码成自然语言文本。
790+
791+
- 得到的就是答案。
792+
793+
这其实就是在说:
794+
795+
> “根据你的理解,答案应该在这段文字中的第 X 到第 Y 个词之间,请把这部分原文告诉我。”
796+
797+
---
798+
799+
4. 🧪 举个例子
800+
801+
假设原始上下文是:
802+
803+
```
804+
The capital of France is Paris.
805+
```
806+
807+
经过 Tokenizer 编码后可能是:
808+
809+
```
810+
[CLS] the capital of france is paris [SEP]
811+
```
812+
如果模型预测 start_index=5,end_index=5,那么对应的就是单词 `"paris"`,这就是答案。
813+
814+
---
815+
816+
⚠️ 注意事项
817+
818+
1. **不能超出上下文范围**
819+
- start/end positions 必须落在上下文部分(即 token_type_id == 1 的区域)。
820+
- 否则答案可能不合理(比如取到了问题部分的内容)。
821+
822+
2. **特殊 token 不计入答案**
823+
- `[CLS]`, `[SEP]` 等会被 `skip_special_tokens=True` 自动跳过。
824+
825+
3. **无法处理不在原文中的答案**
826+
- 如果正确答案没有出现在上下文中,BERT 无法“编造”出来。
827+
- 这是抽取式问答模型的局限性。
828+
829+
---
830+
831+
💡 对比:生成式 vs 抽取式问答
832+
833+
| 类型 | 模型代表 | 是否能生成新文本 | 答案是否必须在原文中 | 示例 |
834+
|------|----------|------------------|-----------------------|------|
835+
| 抽取式 | BERT ||| 答案是原文中的一段 |
836+
| 生成式 | T5 / BART / GPT ||| 答案可以是任意文本 |
837+
838+
如果你希望模型能“自己写答案”,那就需要使用生成式模型。
839+
840+
---
841+
842+
✅ 总结
843+
844+
| 问题 | 回答 |
845+
|------|------|
846+
| 为什么答案来自 `input_ids`| 因为 BERT 是编码器模型,只做抽取式问答,答案必须是原文中的一段文本。 |
847+
| BERT 能不能自己生成答案? | 不能,BERT 不具备生成能力,只能对输入文本中的 token 做分类。 |
848+
| 如何获取答案? | 根据预测的 start/end index,从 `input_ids` 中提取 token,并用 tokenizer 解码成自然语言。 |
849+
850+
851+
### Token分类任务
852+
853+
Token 分类任务是指对输入文本中的每个 token 进行分类,常见的应用场景包括:
854+
855+
- 命名实体识别 (NER)
856+
- 词性标注 (POS)
857+
- 语义角色标注 (SRL)
858+
859+
```python
860+
class BertForTokenClassification(BertPreTrainedModel):
861+
def __init__(self, config):
862+
super(BertForTokenClassification, self).__init__(config)
863+
self.num_labels = config.num_labels
864+
self.bert = BertModel(config)
865+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
866+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
867+
868+
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
869+
position_ids=None, head_mask=None, labels=None):
870+
871+
outputs = self.bert(input_ids,
872+
attention_mask=attention_mask,
873+
token_type_ids=token_type_ids,
874+
position_ids=position_ids,
875+
head_mask=head_mask)
876+
877+
sequence_output = outputs[0] # (batch,seq_len,hidden_size)
878+
879+
sequence_output = self.dropout(sequence_output)
880+
logits = self.classifier(sequence_output) # (batch,seq_len,num_labels)
881+
882+
outputs = (logits,)
883+
if labels is not None:
884+
loss_fct = CrossEntropyLoss()
885+
# Only keep active parts of the loss
886+
if attention_mask is not None:
887+
active_loss = attention_mask.view(-1) == 1
888+
active_logits = logits.view(-1, self.num_labels)[active_loss]
889+
active_labels = labels.view(-1)[active_loss]
890+
loss = loss_fct(active_logits, active_labels)
891+
else:
892+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
893+
outputs = (loss,) + outputs
894+
895+
return outputs # (loss), scores
896+
```
897+
### 多项选择任务
898+
899+
多项选择任务是指给定一个问题和多个候选答案,模型需要从中选择最合适的答案。常见的应用场景包括:
900+
901+
- 阅读理解任务
902+
903+
- 问答系统中的候选答案选择
904+
905+
- 对话系统中的候选回复选择
906+
907+
908+
在 多项选择题(Multiple Choice) 任务中,BERT 的输入组织形式与普通分类或问答任务略有不同。你需要为每个选项分别构造一个完整的 BERT 输入序列,并将它们组合成一个批次进行处理。
909+
910+
✅ 假设你有一个问题 + 4 个选项:
911+
912+
```json
913+
问题:谁写了《哈姆雷特》?
914+
A. 雨果
915+
B. 歌德
916+
C. 莎士比亚
917+
D. 托尔斯泰
918+
```
919+
920+
对于这样的多选问题,BERT 的输入方式是:
921+
922+
对每一个选项,都单独构造一个 `[CLS] + 问题 + [SEP] + 选项内容 + [SEP]` 的输入序列。
923+
924+
也就是说,模型会对每个选项分别编码 ,然后从中选出最合适的那个。
925+
926+
```python
927+
class BertForMultipleChoice(BertPreTrainedModel):
928+
def __init__(self, config):
929+
super(BertForMultipleChoice, self).__init__(config)
930+
self.bert = BertModel(config)
931+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
932+
self.classifier = nn.Linear(config.hidden_size, 1)
933+
934+
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
935+
position_ids=None, head_mask=None, labels=None):
936+
# 获取选项个数
937+
num_choices = input_ids.shape[1] # (batch_size, num_choices, seq_length)
938+
# 将选项展平,以便一起处理: (batch_size * num_choices, seq_length)
939+
input_ids = input_ids.view(-1, input_ids.size(-1))
940+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
941+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
942+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
943+
944+
outputs = self.bert(input_ids,
945+
attention_mask=attention_mask,
946+
token_type_ids=token_type_ids,
947+
position_ids=position_ids,
948+
head_mask=head_mask)
949+
950+
pooled_output = outputs[1] # (batch_size * num_choices, hidden_size)
951+
952+
pooled_output = self.dropout(pooled_output)
953+
logits = self.classifier(pooled_output) # (batch_size * num_choices, 1)
954+
reshaped_logits = logits.view(-1, num_choices) # (batch_size , num_choices, 1)
955+
956+
outputs = (reshaped_logits,)
957+
958+
if labels is not None:
959+
loss_fct = CrossEntropyLoss()
960+
loss = loss_fct(reshaped_logits, labels)
961+
outputs = (loss,) + outputs
962+
963+
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
964+
```
965+
在前向传播中,会将这些输入展平,变成:
966+
967+
```python
968+
input_ids.view(-1, seq_length) # (batch_size * num_choices, seq_length)
969+
```
970+
971+
这样就能让 BERT 对每个选项分别进行编码。
972+
973+
BERT 输出后,再对每个选项做分类打分,最后重新 reshape 成 (batch_size, num_choices) 形式,用于计算交叉熵损失。
974+

src/LLM/图解BERT/21.png

211 KB
Loading

src/LLM/图解BERT/22.png

182 KB
Loading

0 commit comments

Comments
 (0)