Skip to content

Commit 951af14

Browse files
committed
updates
1 parent fd35d81 commit 951af14

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,52 @@ class Blip2Qformer(Blip2Base):
140140
1、Image-Text Contrastive Learning (ITC Loss, CLIP-like)
141141

142142
> 目的: Image representation 与 Text representation,以最大化互信息
143+
>
143144
> 自注意力掩码策略: Uni-modal Self-attention Mask(单模态自注意力)
145+
>
144146
> Queries 和 Text 仅能和自己的 tokens 做 attention(Query和Query、Text和Text)
147+
>
145148
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
146149
147150
image_feats 中每个 image_feat 与 text_feat 计算一个 similarity score ,选择最大值作为这个图文对的相似度 :
148151

149152
![similarity score](庖丁解牛BLIP2/5.png)
150153

151154

152-
153-
154-
155+
```python
156+
###============== Image-text Contrastive ===================###
157+
# 计算每个query token 和 text_feat 的相似度 , 得到相似度矩阵 (B,B,seq_len)
158+
# image_feats (B,seq_len,hidden_size) 变为 (B,1,seq_len,hidden_size)
159+
# text_feat (B,hidden_size) 变为 (B,hidden_size,1)
160+
sim_q2t = torch.matmul(
161+
image_feats.unsqueeze(1), text_feat.unsqueeze(-1)
162+
).squeeze()
163+
164+
# image-text similarity: aggregate across all query tokens
165+
# 保留和text_feat相似度最大的那个query token作为最后的相似度得分 , 维度为 (B,B)
166+
sim_i2t, _ = sim_q2t.max(-1)
167+
sim_i2t = sim_i2t / self.temp
168+
169+
# 反过来计算text_feat 和 每个query token的相似度 , 得到相似度矩阵 (B,B,seq_len)
170+
# image_feats 维度变为 (B,hidden_size,seq_len)
171+
# text_feat (B,hidden_size) 变为 (B,1,1,hidden_size)
172+
sim_t2q = torch.matmul(
173+
text_feat.unsqueeze(1).unsqueeze(1), image_feats.permute(0, 2, 1)
174+
).squeeze()
175+
176+
# text-image similarity: aggregate across all query tokens
177+
# 保留和text_feat相似度最大的那个query token作为最后的相似度得分 , 维度为 (B,B)
178+
sim_t2i, _ = sim_t2q.max(-1)
179+
sim_t2i = sim_t2i / self.temp
180+
181+
# 生成比标签
182+
targets = torch.arange(image.size(0), device=image.device)
183+
184+
# 计算 图文对比 Loss
185+
loss_itc = (
186+
F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
187+
) / 2
188+
```
155189

156190

157191
BertLayer 核心代码实现如下:

0 commit comments

Comments
 (0)