Commit f9fbe99

Merge pull request #6 from BinaryOracle/master

Master

2 parents 7a0b727 + 875b796

File tree

2 files changed: +43 −1 lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 43 additions & 1 deletion
@@ -135,22 +135,64 @@ class Blip2Qformer(Blip2Base):
    )
    ...
```

> In the code comments above, `B` is used uniformly in place of both image_batch and text_batch; seq_len and hidden_size are shorthand in the same way. Take care to distinguish them from context.

To train the Q-Former well, the first stage defines three training objectives, as follows:

1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)

> Purpose: align the image representation with the text representation so as to maximize their mutual information
>
> Self-attention mask strategy: Uni-modal Self-attention Mask
>
> Queries and Text can only attend to their own tokens (Query with Query, Text with Text)
>
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)

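The uni-modal mask described above can be sketched as a block-diagonal boolean matrix. The sizes below and the "1 = may attend" convention are illustrative assumptions, not the exact BLIP-2 implementation:

```python
import torch

# Illustrative sizes (assumptions, not BLIP-2's real values)
num_query, seq_len = 4, 6
total = num_query + seq_len

# 1 = the token pair may attend to each other, 0 = masked out
mask = torch.zeros(total, total, dtype=torch.long)
mask[:num_query, :num_query] = 1  # query tokens attend only to query tokens
mask[num_query:, num_query:] = 1  # text tokens attend only to text tokens
```

The result is two diagonal blocks of ones with zeros everywhere else, so no information flows between the query and text streams during ITC.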
Each image_feat in image_feats is scored against text_feat, and the maximum of those similarity scores is taken as the similarity of the image-text pair:

![similarity score](庖丁解牛BLIP2/5.png)

How is the loss computed? With "in-batch negatives", the approach that CLIP popularized in vision-language pretraining. The figure below, taken from the CLIP paper, illustrates it:

![in-batch negatives](庖丁解牛BLIP2/6.png)
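As a minimal toy illustration of in-batch negatives (the 3×3 similarity values below are made up, not model outputs): each image is scored against every text in the batch, the diagonal entries are the matched positives, and cross-entropy with `arange` targets pushes each row's diagonal score above the rest:

```python
import torch
import torch.nn.functional as F

# Made-up 3x3 similarity matrix: row i scores image i against all
# 3 texts in the batch; the diagonal holds the matched (positive) pairs.
sim = torch.tensor([[5.0, 1.0, 0.0],
                    [0.5, 4.0, 1.0],
                    [0.0, 1.0, 6.0]])
targets = torch.arange(3)          # the positive for row i is column i
loss = F.cross_entropy(sim, targets)
```

Because the diagonal already dominates each row, the loss comes out small; shuffling the rows would raise it.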

```python
###============== Image-text Contrastive ===================###
# Compute the similarity between every query token and text_feat,
# giving a similarity matrix of shape (B, B, seq_len):
# image_feats (B, seq_len, hidden_size) -> (B, 1, seq_len, hidden_size)
# text_feat   (B, hidden_size)          -> (B, hidden_size, 1)
sim_q2t = torch.matmul(
    image_feats.unsqueeze(1), text_feat.unsqueeze(-1)
).squeeze()

# image-text similarity: aggregate across all query tokens by keeping
# the query token most similar to text_feat as the final score; (B, B)
sim_i2t, _ = sim_q2t.max(-1)
sim_i2t = sim_i2t / self.temp

# The reverse direction: similarity of text_feat with every query
# token, again a (B, B, seq_len) similarity matrix:
# image_feats -> (B, hidden_size, seq_len)
# text_feat (B, hidden_size) -> (B, 1, 1, hidden_size)
sim_t2q = torch.matmul(
    text_feat.unsqueeze(1).unsqueeze(1), image_feats.permute(0, 2, 1)
).squeeze()

# text-image similarity: aggregate across all query tokens, again
# keeping the most similar query token as the final score; (B, B)
sim_t2i, _ = sim_t2q.max(-1)
sim_t2i = sim_t2i / self.temp

# Generate the contrastive labels: the i-th image matches the i-th text
targets = torch.arange(image.size(0), device=image.device)

# Image-text contrastive loss, averaged over both directions
loss_itc = (
    # sim_i2t has shape (B, B); each row holds one image's
    # similarity to every text in the batch
    F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
    + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
```
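The shape manipulations above can be checked end to end with random tensors. This is a self-contained sketch: `B`, `num_query`, and `hidden` are made-up sizes, the temperature value is an assumption, and the random features stand in for the L2-normalized projections the real model produces:

```python
import torch
import torch.nn.functional as F

B, num_query, hidden = 4, 32, 256  # illustrative sizes, not BLIP-2's
temp = 0.07                        # assumed temperature value

# Random stand-ins for the normalized Q-Former outputs
image_feats = F.normalize(torch.randn(B, num_query, hidden), dim=-1)
text_feat = F.normalize(torch.randn(B, hidden), dim=-1)

# (B,1,num_query,hidden) @ (B,hidden,1) broadcasts to (B,B,num_query,1)
sim_q2t = torch.matmul(
    image_feats.unsqueeze(1), text_feat.unsqueeze(-1)
).squeeze()                          # (B, B, num_query)
sim_i2t = sim_q2t.max(-1)[0] / temp  # (B, B)

# (B,1,1,hidden) @ (B,hidden,num_query) broadcasts to (B,B,1,num_query)
sim_t2q = torch.matmul(
    text_feat.unsqueeze(1).unsqueeze(1), image_feats.permute(0, 2, 1)
).squeeze()                          # (B, B, num_query)
sim_t2i = sim_t2q.max(-1)[0] / temp  # (B, B)

targets = torch.arange(B)
loss_itc = (
    F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
    + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
```

Note how `torch.matmul` broadcasting produces the (B, B, num_query) matrix: unsqueezing one side to a singleton batch dimension pairs every image with every text in the batch, which is exactly what in-batch negatives require.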

src/MMLLM/庖丁解牛BLIP2/6.png

107 KB