dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels)
```

## Model Architecture

### DataLoader

```python
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate_fn)
```

The `collate_fn` callback set on the DataLoader preprocesses each batch before it is returned:

```python
def collate_fn(batch):
    all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))
    max_len = max(all_lens).item()  # the actual maximum sequence length within the current batch
    all_input_ids = all_input_ids[:, :max_len]  # truncate to the longest sequence in this batch: max_length --> max_len
    all_attention_mask = all_attention_mask[:, :max_len]
    all_token_type_ids = all_token_type_ids[:, :max_len]
    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels
```
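
To see the dynamic truncation at work, here is a minimal sketch that feeds `collate_fn` a hand-built batch; the token ids, lengths, and padding width are invented for illustration:

```python
import torch

# Two samples padded to max_length = 8; their real lengths are 4 and 3.
batch = [
    (torch.tensor([101, 7592, 2088, 102, 0, 0, 0, 0]),  # input_ids
     torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]),            # attention_mask
     torch.tensor([0, 0, 0, 0, 0, 0, 0, 0]),            # token_type_ids
     torch.tensor(4),                                    # real length
     torch.tensor(1)),                                   # label
    (torch.tensor([101, 2026, 102, 0, 0, 0, 0, 0]),
     torch.tensor([1, 1, 1, 0, 0, 0, 0, 0]),
     torch.tensor([0, 0, 0, 0, 0, 0, 0, 0]),
     torch.tensor(3),
     torch.tensor(0)),
]

input_ids, attention_mask, token_type_ids, labels = collate_fn(batch)
print(input_ids.shape)  # torch.Size([2, 4]): truncated from 8 to the batch max of 4
```

Because each batch is truncated to its own longest sequence rather than the global `max_length`, batches of short sequences waste no computation on padding positions.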

### BertEmbeddings

![input embeddings = token embeddings + segmentation embeddings + position embeddings](图解BERT/7.png)

```python
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            # generate a position sequence (0, 1, 2, ...) for every sample in the
            # batch, forming a matrix of position ids
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)  # the position encoding is a learnable matrix
        token_type_embeddings = self.token_type_embeddings(token_type_ids)  # lets the model learn to distinguish the sentences on its own

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```

![How the embedding vectors are produced](图解BERT/8.png)
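
As a quick sanity check on shapes, the sketch below instantiates `BertEmbeddings` with a hypothetical toy config; `BertLayerNorm` is assumed to behave like `torch.nn.LayerNorm` (the original code uses an equivalent implementation):

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

BertLayerNorm = nn.LayerNorm  # assumption: stands in for the library's BertLayerNorm

# toy config; real BERT-base uses vocab_size=30522, hidden_size=768, etc.
config = SimpleNamespace(vocab_size=100, hidden_size=16, max_position_embeddings=32,
                         type_vocab_size=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1)

embeddings = BertEmbeddings(config)
input_ids = torch.randint(1, 100, (2, 10))  # batch of 2 sequences, 10 tokens each
out = embeddings(input_ids)                 # token + position + segment, then LayerNorm + dropout
print(out.shape)                            # torch.Size([2, 10, 16])
```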

### BertEncoder

#### BertLayer

![BertLayer architecture](图解BERT/9.png)

```python
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)  # (768, 3072)
        # activation function: GELU
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)  # activation function: GELU
        return hidden_states

class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # (3072, 768)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)  # residual connection, then LayerNorm
        return hidden_states

class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
```
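
Taken together, BertIntermediate and BertOutput form the position-wise feed-forward sublayer: expand 768 → 3072, apply GELU, project 3072 → 768, then add the residual and apply LayerNorm (the post-LN arrangement). A standalone sketch with plain `nn` modules, dropout omitted for brevity:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 768, 3072             # BERT-base sizes
dense_in = nn.Linear(hidden_size, intermediate_size)   # plays the role of BertIntermediate.dense
act = nn.GELU()                                        # plays the role of ACT2FN["gelu"]
dense_out = nn.Linear(intermediate_size, hidden_size)  # plays the role of BertOutput.dense
norm = nn.LayerNorm(hidden_size, eps=1e-12)

x = torch.randn(2, 10, hidden_size)  # attention_output: (batch, seq_len, hidden)
h = dense_out(act(dense_in(x)))      # expand -> GELU -> project back
y = norm(h + x)                      # residual connection, then LayerNorm (post-LN)
print(y.shape)                       # torch.Size([2, 10, 768])
```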

#### BertEncoder

![BertEncoder architecture](图解BERT/11.png)

```python
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        # a stack of num_hidden_layers identical BertLayers (12 for BERT-base)
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states
```

### BertPooler

![BertPooler architecture](图解BERT/10.png)

```python
class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]  # the [CLS] token's contextual embedding
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```
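
The slicing is worth a second look: `hidden_states[:, 0]` keeps the batch dimension and selects position 0, the [CLS] token, from every sequence. A toy illustration with invented sizes:

```python
import torch

hidden_states = torch.randn(2, 10, 768)   # (batch, seq_len, hidden_size)
first_token_tensor = hidden_states[:, 0]  # one [CLS] vector per sequence
print(first_token_tensor.shape)           # torch.Size([2, 768])
```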

### BertModel

![BertModel architecture](图解BERT/12.png)

```python
class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # broadcast the (batch, seq_len) padding mask to (batch, 1, 1, seq_len)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # padded positions become -10000.0 so that softmax gives them ~zero attention
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        sequence_output = self.encoder(embedding_output,
                                       extended_attention_mask,  # padding mask
                                       )
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,)
        return outputs
```
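
The mask arithmetic in `forward` is easiest to see with concrete numbers; the mask values below are invented for illustration:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])    # 1 = real token, 0 = padding
extended = attention_mask.unsqueeze(1).unsqueeze(2)  # (1, 1, 1, 5): broadcasts over heads and query positions
extended = (1.0 - extended.float()) * -10000.0
print(extended)
# tensor([[[[    -0.,     -0.,     -0., -10000., -10000.]]]])
# Added to the raw attention scores before softmax, the -10000 entries drive
# the attention weights on padded positions to effectively zero.
```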

### BertForSequenceClassification

![BertForSequenceClassification architecture](图解BERT/13.png)

```python
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,  # padding mask
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]  # for classification, only the pooled [CLS] representation is needed

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
```
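
For completeness, a minimal sketch of one fine-tuning pass, assuming the `model` and `train_dataloader` built above; the choice of `torch.optim.AdamW` and the learning rate are illustrative, not taken from the original training script:

```python
from torch.optim import AdamW  # assumption: the original repo may ship its own optimizer

model.train()
optimizer = AdamW(model.parameters(), lr=2e-5)

for batch in train_dataloader:
    # collate_fn returns exactly these four tensors
    input_ids, attention_mask, token_type_ids, labels = batch
    outputs = model(input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    labels=labels)
    loss = outputs[0]  # the loss comes first when labels are provided
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```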