@@ -302,8 +302,7 @@ class BertLayer(nn.Module):
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
-        attention_outputs = self.attention(hidden_states, attention_mask)
-        attention_output = attention_outputs[0]
+        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
@@ -415,4 +414,99 @@ class BertForSequenceClassification(BertPreTrainedModel):
        outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
-```
+```
+
+### BertAttention
+
+#### BertSelfAttention
+
+![Multi-head self-attention computation flow](图解BERT/14.png)
+
+```python
+class BertSelfAttention(nn.Module):
+    def __init__(self, config):
+        super(BertSelfAttention, self).__init__()
+        self.output_attentions = config.output_attentions
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        # (batch, seq_len, hidden_size) -> (batch, heads, seq_len, d_k)
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+        # Reshape into the multi-head format: (batch, heads, seq_len, d_k)
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # (batch, heads, seq_len, seq_len)
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        # Merge the heads back: (batch, heads, seq_len, d_k) -> (batch, seq_len, hidden_size)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        return context_layer
+```
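+
+A quick shape check helps confirm the reshaping logic. The following is a minimal sketch, not part of the original code: it uses a `SimpleNamespace` as a hypothetical stand-in for the BERT config, carrying only the attributes `BertSelfAttention` reads (the values shown are the BERT-base defaults).
+
+```python
+import math
+import torch
+import torch.nn as nn
+from types import SimpleNamespace
+
+# Hypothetical stand-in for the real BertConfig; values follow BERT-base.
+config = SimpleNamespace(
+    output_attentions=False,
+    num_attention_heads=12,
+    hidden_size=768,
+    attention_probs_dropout_prob=0.1,
+)
+
+self_attn = BertSelfAttention(config)
+hidden_states = torch.randn(2, 128, 768)   # (batch, seq_len, hidden_size)
+context = self_attn(hidden_states)         # no mask: every position attends to all positions
+print(context.shape)                       # torch.Size([2, 128, 768])
+```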
+
+#### BertSelfOutput
+
+![BertSelfOutput computation flow](图解BERT/15.png)
+
+```python
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super(BertSelfOutput, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    # Residual connection + layer normalization ("Add & Norm")
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+```
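+
+In Transformer terms this is the "Add & Norm" sub-step: the attention output is projected, dropped out, added back to the layer's original input, and layer-normalized. A minimal usage sketch, not from the original article, assuming everything lives in one script and using plain `nn.LayerNorm` as a stand-in for `BertLayerNorm`:
+
+```python
+import torch
+import torch.nn as nn
+from types import SimpleNamespace
+
+BertLayerNorm = nn.LayerNorm   # assumption: plain LayerNorm as a stand-in
+
+config = SimpleNamespace(hidden_size=768, layer_norm_eps=1e-12, hidden_dropout_prob=0.1)
+self_output = BertSelfOutput(config)
+
+attn_out = torch.randn(2, 128, 768)    # output of BertSelfAttention
+residual = torch.randn(2, 128, 768)    # the original input to the attention block
+out = self_output(attn_out, residual)  # LayerNorm(Dropout(Linear(attn_out)) + residual)
+print(out.shape)                       # torch.Size([2, 128, 768])
+```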
+
+#### BertAttention
+
+![BertAttention computation flow](图解BERT/16.png)
+
+```python
+class BertAttention(nn.Module):
+    def __init__(self, config):
+        super(BertAttention, self).__init__()
+        self.self = BertSelfAttention(config)
+        self.output = BertSelfOutput(config)
+
+    def forward(self, input_tensor, attention_mask=None):
+        self_outputs = self.self(input_tensor, attention_mask)       # multi-head self-attention
+        attention_output = self.output(self_outputs, input_tensor)   # Add & Norm against the residual input
+        return attention_output
+```
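+
+The `attention_mask` consumed above is additive: it is precomputed once in `BertModel.forward()` as 0 for real tokens and a large negative value (-10000 in the Hugging Face code) for padding positions, so the softmax pushes masked positions toward zero weight. A minimal end-to-end sketch, not from the original article, reusing the hypothetical config from the earlier examples:
+
+```python
+import torch
+from types import SimpleNamespace
+
+# Assumes BertLayerNorm is defined (e.g. nn.LayerNorm, as in the previous sketch).
+# Hypothetical config combining the attributes read by BertSelfAttention and BertSelfOutput.
+config = SimpleNamespace(
+    output_attentions=False,
+    num_attention_heads=12,
+    hidden_size=768,
+    attention_probs_dropout_prob=0.1,
+    layer_norm_eps=1e-12,
+    hidden_dropout_prob=0.1,
+)
+
+attention = BertAttention(config)
+
+hidden_states = torch.randn(2, 8, 768)             # (batch, seq_len, hidden_size)
+mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],     # 1 = real token, 0 = padding
+                     [1, 1, 1, 0, 0, 0, 0, 0]])
+# Additive mask broadcast over heads and query positions: 0 keeps, -10000 blocks.
+extended_mask = (1.0 - mask[:, None, None, :].float()) * -10000.0   # (batch, 1, 1, seq_len)
+
+out = attention(hidden_states, extended_mask)
+print(out.shape)                                   # torch.Size([2, 8, 768])
+```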