Commit 252298d

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
Each image_feat in image_feats is compared with text_feat to produce a similarity score,

# Compute the cross-entropy loss
loss_itm = F.cross_entropy(logits, itm_labels)
```
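A minimal, self-contained sketch of this loss computation (the shapes here are hypothetical for illustration: a batch of 4 image-text pairs, with the ITM head emitting 2 logits per pair, where label 0 means mismatch and 1 means match):

```python
import torch
import torch.nn.functional as F

# Hypothetical ITM head output: one (mismatch, match) logit pair per sample
logits = torch.randn(4, 2)
itm_labels = torch.tensor([1, 0, 1, 1])

# Standard cross-entropy over the 2 classes, averaged over the batch
loss_itm = F.cross_entropy(logits, itm_labels)
assert loss_itm.ndim == 0  # a scalar loss, ready for backward()
```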
When text and query tokens are fed into BertModel together, BertEmbeddings concatenates the text embeddings and the query-token embeddings along the seq_len dimension.

```python
class BertEmbeddings(nn.Module):
    ...
    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        # Compute the text sequence length
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        # Generate position ids automatically if none are provided
        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        # Add word embeddings and position embeddings;
        # prepend query_embeds if present
        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```
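The concatenation step can be sketched with shapes alone (the sizes here are hypothetical: 32 query tokens, a text length of 12, hidden size 768):

```python
import torch

batch, num_query, text_len, hidden = 2, 32, 12, 768

query_embeds = torch.randn(batch, num_query, hidden)  # learned query tokens
text_embeds = torch.randn(batch, text_len, hidden)    # word + position embeddings

# BertEmbeddings prepends the query tokens along the sequence dimension,
# so downstream BertLayers see one sequence of length num_query + text_len
embeddings = torch.cat((query_embeds, text_embeds), dim=1)
print(embeddings.shape)  # torch.Size([2, 44, 768])
```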

The core implementation of BertLayer is as follows:
