@@ -135,22 +135,64 @@ class Blip2Qformer(Blip2Base):
135
135
)
136
136
...
137
137
```
138
+
139
+ > 以上代码注释中统一用B代替image_batch和text_batch,以及seq_len和hidden_size也是同样处理手段,大家注意区分。
140
+
138
141
为了训练好Q-Former,第一阶段设计了三个训练目标,分别如下:
139
142
140
143
1、Image-Text Contrastive Learning (ITC Loss, CLIP-like)
141
144
142
145
> 目的: Image representation 与 Text representation,以最大化互信息
146
+ >
143
147
> 自注意力掩码策略: Uni-modal Self-attention Mask(单模态自注意力)
148
+ >
144
149
> Queries 和 Text 仅能和自己的 tokens 做 attention(Query和Query、Text和Text)
150
+ >
145
151
> ![ Uni-modal Self-attention Mask] ( 庖丁解牛BLIP2/4.png )
146
152
147
153
image_feats 中每个 image_feat 与 text_feat 计算一个 similarity score ,选择最大值作为这个图文对的相似度 :
148
154
149
155
![ similarity score] ( 庖丁解牛BLIP2/5.png )
150
156
157
+ 如何计算loss的: “in-batch negatives”,该方法正是CLIP在VLP领域发扬光大的。以下引用CLIP论文图做说明:
151
158
159
+ ![ in-batch negatives] ( 庖丁解牛BLIP2/6.png )
152
160
153
-
161
+ ``` python
162
+ # ##============== Image-text Contrastive ===================###
163
+ # 计算每个query token 和 text_feat 的相似度 , 得到相似度矩阵 (B,B,seq_len)
164
+ # image_feats (B,seq_len,hidden_size) 变为 (B,1,seq_len,hidden_size)
165
+ # text_feat (B,hidden_size) 变为 (B,hidden_size,1)
166
+ sim_q2t = torch.matmul(
167
+ image_feats.unsqueeze(1 ), text_feat.unsqueeze(- 1 )
168
+ ).squeeze()
169
+
170
+ # image-text similarity: aggregate across all query tokens
171
+ # 保留和text_feat相似度最大的那个query token作为最后的相似度得分 , 维度为 (B,B)
172
+ sim_i2t, _ = sim_q2t.max(- 1 )
173
+ sim_i2t = sim_i2t / self .temp
174
+
175
+ # 反过来计算text_feat 和 每个query token的相似度 , 得到相似度矩阵 (B,B,seq_len)
176
+ # image_feats 维度变为 (B,hidden_size,seq_len)
177
+ # text_feat (B,hidden_size) 变为 (B,1,1,hidden_size)
178
+ sim_t2q = torch.matmul(
179
+ text_feat.unsqueeze(1 ).unsqueeze(1 ), image_feats.permute(0 , 2 , 1 )
180
+ ).squeeze()
181
+
182
+ # text-image similarity: aggregate across all query tokens
183
+ # 保留和text_feat相似度最大的那个query token作为最后的相似度得分 , 维度为 (B,B)
184
+ sim_t2i, _ = sim_t2q.max(- 1 )
185
+ sim_t2i = sim_t2i / self .temp
186
+
187
+ # 生成比标签
188
+ targets = torch.arange(image.size(0 ), device = image.device)
189
+
190
+ # 计算 图文对比 Loss
191
+ loss_itc = (
192
+ # sim_i2t 形状是 (B, B),每一行表示一张图像和所有文本之间的相似度。
193
+ F.cross_entropy(sim_i2t, targets, label_smoothing = 0.1 ) + F.cross_entropy(sim_t2i, targets, label_smoothing = 0.1 )
194
+ ) / 2
195
+ ```
154
196
155
197
156
198
0 commit comments