Commit f24cfe7

updates
1 parent 831f754 commit f24cfe7

File tree

9 files changed, +97 -3 lines changed


src/LLM/图解BERT.md

Lines changed: 97 additions & 3 deletions

@@ -302,8 +302,7 @@ class BertLayer(nn.Module):
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask)
-        attention_output = attention_outputs[0]
+        attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
         return layer_output

@@ -415,4 +414,99 @@ class BertForSequenceClassification(BertPreTrainedModel):
             outputs = (loss,) + outputs

         return outputs  # (loss), logits, (hidden_states), (attentions)
-```
+```

The only change inside the old lines is the re-emitted closing code fence; the remaining additions in this hunk append a new section to the end of the file:

### BertAttention

#### BertSelfAttention

![Multi-head self-attention computation flow](图解BERT/14.png)

```python
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        # Reshape (view) into multi-head format: (batch, heads, seq_len, d_k)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (batch, heads, seq_len, seq_len)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)  # concatenate the per-head results back together
        return context_layer
```
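
The `attention_mask` added to `attention_scores` is an additive mask that BertModel prepares once and reuses in every layer. A minimal sketch of that convention (the preparation code is not shown in this excerpt, so treat the details as an assumption): a `(batch, seq_len)` padding mask of 1s and 0s is broadcast to `(batch, 1, 1, seq_len)`, and padded positions are mapped to a large negative value so the softmax drives their attention weights toward zero.

```python
import torch

# (batch, seq_len) padding mask: 1 = real token, 0 = padding (hypothetical example)
padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])

# Broadcast to (batch, 1, 1, seq_len) so it can be added to
# attention_scores of shape (batch, heads, seq_len, seq_len).
extended_mask = padding_mask[:, None, None, :].to(torch.float32)

# 0.0 for real tokens, -10000.0 for padding; after the softmax the
# padded columns receive (almost) no attention weight.
extended_mask = (1.0 - extended_mask) * -10000.0
```

Passing `extended_mask` as `attention_mask` makes the `attention_scores + attention_mask` line above suppress attention to padded positions.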

#### BertSelfOutput

![BertSelfOutput computation flow](图解BERT/15.png)

```python
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # residual connection + layer normalization
    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
```
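
BertSelfOutput is the "Add & Norm" step for the attention sub-layer: a linear projection and dropout, then the residual connection and LayerNorm (post-LN). A minimal sketch checking that reading against the module itself, assuming `BertLayerNorm` is an alias for `nn.LayerNorm` (it is not defined in this excerpt) and that the class above is in scope:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

BertLayerNorm = nn.LayerNorm  # assumption: alias used by the rest of the file

# Toy stand-in for BertConfig; dropout set to 0 so the check is exact.
cfg = SimpleNamespace(hidden_size=8, layer_norm_eps=1e-12, hidden_dropout_prob=0.0)
self_output = BertSelfOutput(cfg).eval()

attn_context = torch.randn(2, 4, cfg.hidden_size)  # output of the attention sub-layer
residual_in = torch.randn(2, 4, cfg.hidden_size)   # input that entered the sub-layer

manual = self_output.LayerNorm(self_output.dense(attn_context) + residual_in)
assert torch.allclose(self_output(attn_context, residual_in), manual)
```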

#### BertAttention

![BertAttention computation flow](图解BERT/16.png)

```python
class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask=None):
        self_outputs = self.self(input_tensor, attention_mask)  # multi-head self-attention
        attention_output = self.output(self_outputs, input_tensor)
        return attention_output
```
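
A short usage sketch (not part of the original post) wiring the three modules together. The `SimpleNamespace` config is a hypothetical stand-in for the real BertConfig with the field names used above; the class definitions are assumed to be in scope.

```python
import math            # needed by BertSelfAttention.forward
import torch
import torch.nn as nn
from types import SimpleNamespace

BertLayerNorm = nn.LayerNorm  # assumption: alias used by BertSelfOutput

config = SimpleNamespace(
    hidden_size=768,
    num_attention_heads=12,
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    layer_norm_eps=1e-12,
    output_attentions=False,
)

attention = BertAttention(config)
hidden_states = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden)

out = attention(hidden_states)  # no mask: every position attends to every position
print(out.shape)                # torch.Size([2, 16, 768])
```

The output keeps the input shape, which is what lets BertLayer chain the attention block with the feed-forward (intermediate + output) block and stack 12 or 24 such layers.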

src/LLM/图解BERT/10.png (2.03 KB)

src/LLM/图解BERT/11.png (2.14 KB)

src/LLM/图解BERT/12.png (1.26 KB)

src/LLM/图解BERT/13.png (71 Bytes)

src/LLM/图解BERT/14.png (388 KB)

src/LLM/图解BERT/15.png (569 KB)

src/LLM/图解BERT/16.png (666 KB)

src/LLM/图解BERT/9.png (6.59 KB)

0 commit comments
