# compute the cross-entropy loss
loss_itm = F.cross_entropy(logits, itm_labels)
```
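As a toy illustration of the ITM loss above, 2-way logits (matched vs. unmatched pair) are fed to `F.cross_entropy` together with binary labels; the tensors below are hypothetical stand-ins for the ITM head's outputs, not values from the actual model:

```python
import torch
import torch.nn.functional as F

# hypothetical stand-ins for the ITM head outputs: one 2-way logit
# per image-text pair (class 0 = unmatched, class 1 = matched)
batch_size = 4
logits = torch.randn(batch_size, 2)
itm_labels = torch.tensor([1, 0, 1, 0])

# mean cross-entropy over the batch; the result is a non-negative scalar
loss_itm = F.cross_entropy(logits, itm_labels)
print(loss_itm.dim())  # 0 (a scalar tensor)
```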
When the text and the query tokens are fed into BertModel together, BertEmbeddings concatenates the text embeddings and the query-token embeddings along the seq_len dimension.

```python
class BertEmbeddings(nn.Module):
    ...
    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        # compute the length of the text sequence
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        # generate position ids automatically if none are provided
        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        # add word embeddings and position embeddings;
        # if query_embeds is given, prepend it along the sequence dimension
        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```
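The effect of that concatenation can be sketched with dummy tensors; the batch size and sequence length below are arbitrary, and the 32 query tokens with hidden size 768 follow BLIP-2's defaults:

```python
import torch

batch, num_query, seq_len, hidden = 2, 32, 10, 768

# dummy tensors standing in for the learned query tokens and the
# text word embeddings produced inside BertEmbeddings
query_embeds = torch.randn(batch, num_query, hidden)
text_embeds = torch.randn(batch, seq_len, hidden)

# query tokens are prepended along the sequence dimension (dim=1),
# so the combined sequence length is num_query + seq_len
embeddings = torch.cat((query_embeds, text_embeds), dim=1)
print(embeddings.shape)  # torch.Size([2, 42, 768])
```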
The core implementation of BertLayer is as follows: