
Commit 831f754

committed: updates
1 parent 036c146

File tree

8 files changed: +212 −1 lines changed

src/LLM/图解BERT.md

Lines changed: 212 additions & 1 deletion
@@ -204,4 +204,215 @@ features.append(
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels)
```

## Model Architecture

### DataLoader

```python
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate_fn)
```

The `collate_fn` callback registered on the DataLoader preprocesses each batch before it is returned:

```python
def collate_fn(batch):
    all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))
    max_len = max(all_lens).item()  # actual maximum sequence length within the current batch
    all_input_ids = all_input_ids[:, :max_len]  # truncate to the batch maximum: max_length --> max_len
    all_attention_mask = all_attention_mask[:, :max_len]
    all_token_type_ids = all_token_type_ids[:, :max_len]
    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels
```
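
To see the dynamic padding in action, here is a minimal, self-contained sketch; the toy token ids and the global padded length of 8 are made up for illustration:

```python
import torch

# Two examples pre-padded to a global max_length of 8; each tuple mirrors
# (input_ids, attention_mask, token_type_ids, length, label).
ex1 = (torch.tensor([101, 7592, 102, 0, 0, 0, 0, 0]),
       torch.tensor([1, 1, 1, 0, 0, 0, 0, 0]),
       torch.zeros(8, dtype=torch.long),
       torch.tensor(3), torch.tensor(1))
ex2 = (torch.tensor([101, 2088, 999, 102, 0, 0, 0, 0]),
       torch.tensor([1, 1, 1, 1, 0, 0, 0, 0]),
       torch.zeros(8, dtype=torch.long),
       torch.tensor(4), torch.tensor(0))

input_ids, attention_mask, token_type_ids, labels = collate_fn([ex1, ex2])
print(input_ids.shape)  # torch.Size([2, 4]) -- trimmed from 8 to the batch max of 4
```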

### BertEmbeddings

![input embeddings = token embeddings + segmentation embeddings + position embeddings](图解BERT/7.png)

```python
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Generate a position sequence (0, 1, 2, ...) for every sample in the
            # batch, forming a matrix of position ids
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)  # the position encoding is a learnable matrix
        token_type_embeddings = self.token_type_embeddings(token_type_ids)  # lets the model learn to tell sentences apart

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```

![How the embedding vectors are produced](图解BERT/8.png)
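
As a quick shape check, the three lookups all land in the same (batch, seq_len, hidden_size) space, which is what makes the element-wise sum well defined. A standalone sketch assuming bert-base sizes (vocab_size=30522, hidden_size=768, max_position_embeddings=512, type_vocab_size=2):

```python
import torch
import torch.nn as nn

word_emb = nn.Embedding(30522, 768, padding_idx=0)  # token embeddings
pos_emb = nn.Embedding(512, 768)                    # learned position embeddings
type_emb = nn.Embedding(2, 768)                     # segment (sentence A/B) embeddings

input_ids = torch.randint(0, 30522, (2, 16))        # (batch=2, seq_len=16)
position_ids = torch.arange(16).unsqueeze(0).expand_as(input_ids)
token_type_ids = torch.zeros_like(input_ids)

total = word_emb(input_ids) + pos_emb(position_ids) + type_emb(token_type_ids)
print(total.shape)  # torch.Size([2, 16, 768])
```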

### BertEncoder

#### BertLayer

![BertLayer architecture](图解BERT/9.png)

```python
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)  # (768, 3072)
        # Activation function: GELU
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)  # activation function: GELU
        return hidden_states

class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # (3072, 768)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)  # residual connection, then LayerNorm
        return hidden_states

class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
```
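
The feed-forward half of the layer expands the hidden size to intermediate_size and projects it back, with a residual connection and post-LayerNorm. A minimal sketch of that flow, assuming bert-base sizes and using plain nn.GELU in place of ACT2FN["gelu"]:

```python
import torch
import torch.nn as nn

hidden = torch.randn(2, 16, 768)  # (batch, seq_len, hidden_size)
expand = nn.Linear(768, 3072)     # plays the role of BertIntermediate.dense
project = nn.Linear(3072, 768)    # plays the role of BertOutput.dense
norm = nn.LayerNorm(768)

ffn = project(nn.GELU()(expand(hidden)))  # position-wise feed-forward
out = norm(ffn + hidden)                  # residual add, then LayerNorm (post-norm)
print(out.shape)                          # torch.Size([2, 16, 768])
```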

#### BertEncoder

![BertEncoder architecture](图解BERT/11.png)

```python
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        # Feed the output of each BertLayer into the next
        # (head_mask is dropped here, since BertLayer above does not accept it)
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states
```

### BertPooler

![BertPooler architecture](图解BERT/10.png)

```python
class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]  # the context embedding of the [CLS] token
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
```
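
The slicing is the whole trick: hidden_states[:, 0] keeps one vector per example, so the pooler maps (batch, seq_len, hidden) down to (batch, hidden). A tiny sketch with assumed shapes:

```python
import torch

hidden_states = torch.randn(4, 128, 768)  # encoder output: (batch, seq_len, hidden)
first_token = hidden_states[:, 0]         # [CLS] vectors only: (batch, hidden)
print(first_token.shape)                  # torch.Size([4, 768])
```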

### BertModel

![BertModel architecture](图解BERT/12.png)

```python
class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)  # default: attend to every position
        # Expand (batch, seq_len) to (batch, 1, 1, seq_len) so the mask broadcasts
        # over attention heads and query positions
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        sequence_output = self.encoder(embedding_output,
                                       extended_attention_mask,  # padding mask
                                       )
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,)
        return outputs
```
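
The mask arithmetic is easy to verify by hand: real tokens (1) become 0 and padding tokens (0) become -10000, and this tensor is added to the raw attention scores before softmax, so padded positions get near-zero attention weight. A small check with a made-up 5-token mask:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])     # 1 = real token, 0 = padding
extended = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
extended = (1.0 - extended.float()) * -10000.0
print(extended.squeeze())  # real tokens -> 0., padding -> -10000.
```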

### BertForSequenceClassification

![BertForSequenceClassification architecture](图解BERT/13.png)

```python
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,  # padding mask
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)  # usually None

        pooled_output = outputs[1]  # for classification, only the pooled [CLS] token representation is needed

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attentions if they are present

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
```
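
Putting it together, a hedged end-to-end sketch of one fine-tuning step on dummy inputs; config is an assumed BertConfig-like object exposing the fields read above, and in practice the weights would come from a pretrained checkpoint rather than random initialization:

```python
import torch

model = BertForSequenceClassification(config)  # config: assumed BertConfig-like object
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # dummy token ids
attention_mask = torch.ones(2, 16, dtype=torch.long)      # no padding in this toy batch
token_type_ids = torch.zeros(2, 16, dtype=torch.long)     # single-sentence inputs
labels = torch.tensor([0, 1])

loss, logits = model(input_ids,
                     attention_mask=attention_mask,
                     token_type_ids=token_type_ids,
                     labels=labels)[:2]
loss.backward()      # an optimizer.step() would follow in a real training loop
print(logits.shape)  # (2, config.num_labels)
```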

src/LLM/图解BERT/10.png (335 KB)
src/LLM/图解BERT/11.png (227 KB)
src/LLM/图解BERT/12.png (850 KB)
src/LLM/图解BERT/13.png (1.21 MB)
src/LLM/图解BERT/7.png (62.9 KB)
src/LLM/图解BERT/8.png (298 KB)
src/LLM/图解BERT/9.png (189 KB)

0 commit comments