#### 1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)

> - Goal: align the image representation with the text representation by maximizing their mutual information.
>
> - Self-attention mask strategy: uni-modal self-attention mask (see the sketch after this list).
>
> - Queries and text tokens may only attend to tokens of their own modality (query-to-query, text-to-text).
>
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
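To make the masking rule concrete, here is a minimal, hypothetical sketch (the reference LAVIS code effectively achieves this for ITC by encoding queries and text in separate forward passes): it builds a 0/1 mask over the concatenated `[queries; text]` sequence in which queries attend only to queries and text tokens only to text tokens.

```python
import torch

def unimodal_self_attention_mask(num_query: int, num_text: int) -> torch.Tensor:
    """Hypothetical helper: 1 = attention allowed, 0 = blocked."""
    total = num_query + num_text
    mask = torch.zeros(total, total, dtype=torch.long)
    mask[:num_query, :num_query] = 1  # queries attend only to queries
    mask[num_query:, num_query:] = 1  # text tokens attend only to text tokens
    return mask

print(unimodal_self_attention_mask(2, 3))
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]])
```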
#### 2. Image-Text Matching (ITM Loss, a binary classification task)

> - Goal: learn whether an image-text pair matches, so as to align the image representation with the text representation at a finer granularity.
>
> - Self-attention mask strategy: bi-directional self-attention mask (see the sketch after this list).
>
> - Queries and text tokens can all attend to every token.
>
> ![Bi-directional Self-attention Mask](庖丁解牛BLIP2/7.png)
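In the bi-directional case nothing needs to be blocked: every position in the concatenated `[queries; text]` sequence may attend to every other position, so the only mask the model needs is the usual padding mask. That is what the code further below builds as `attention_mask_all`: an all-ones mask over the query positions concatenated with the text attention mask. A minimal sketch with assumed toy shapes:

```python
import torch

# Assumed toy shapes: 4 image-text pairs, 32 learned queries, 12 text tokens.
batch_size, num_query, num_text = 4, 32, 12
text_attention_mask = torch.ones(batch_size, num_text, dtype=torch.long)

# Queries are always "real" tokens, so their mask is all ones; concatenating it
# with the text padding mask gives full (bi-directional) attention over [queries; text].
query_atts = torch.ones(batch_size, num_query, dtype=torch.long)
attention_mask_all = torch.cat([query_atts, text_attention_mask], dim=1)  # shape (4, 44)
```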
Each output query embedding is fed into a two-class classifier to produce a logit; the average of all the logits is taken as the final matching score:

![matching score](庖丁解牛BLIP2/8.png)
```python
###============== Image-text Matching ===================###

text_input_ids_world = text_tokens.input_ids
text_attention_mask_world = text_tokens.attention_mask
image_embeds_world = image_embeds

with torch.no_grad():
    if "image_id" in samples.keys():
        mask = torch.eq(image_ids, image_ids.t())
        sim_t2i.masked_fill_(mask, -10000)
        sim_i2t.masked_fill_(mask, -10000)
    else:
        # On a single GPU, sim_t2i[b, b] is the sample's own positive pair;
        # mask it out so it cannot be sampled as a "hard negative".
        diag_indices = torch.arange(bs, device=sim_t2i.device)
        sim_t2i[diag_indices, diag_indices] = -10000
        sim_i2t[diag_indices, diag_indices] = -10000

    weights_t2i = F.softmax(sim_t2i, dim=1)
    weights_i2t = F.softmax(sim_i2t, dim=1)

# select a hard negative image for each text
image_embeds_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# select a hard negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_ids_neg.append(text_input_ids_world[neg_idx])
    text_atts_neg.append(text_attention_mask_world[neg_idx])

text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

# Build the ITM inputs: positive pairs + negative pairs
text_ids_all = torch.cat(
    [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
)
text_atts_all = torch.cat(
    [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
    dim=0,
)

query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
    image.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

image_embeds_all = torch.cat(
    [image_embeds, image_embeds_neg, image_embeds], dim=0
)
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
    image.device
)

output_itm = self.Qformer.bert(
    text_ids_all,
    query_embeds=query_tokens_itm,
    attention_mask=attention_mask_all,
    encoder_hidden_states=image_embeds_all,
    encoder_attention_mask=image_atts_all,
    return_dict=True,
)

# Take the query positions, classify each one, and average the per-query logits
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)

# Labels: the first bs rows are matched pairs, the remaining 2 * bs rows are negatives
itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
    dim=0,
).to(image.device)
loss_itm = F.cross_entropy(logits, itm_labels)
```
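The snippet above assumes `self.itm_head` already exists. In the reference implementation it is essentially a two-way linear classifier over the Q-Former hidden size; a minimal sketch of that assumption, with a shape check of the mean-over-queries pooling:

```python
import torch
import torch.nn as nn

hidden_size = 768  # assumed Q-Former hidden size
itm_head = nn.Linear(hidden_size, 2)  # match / no-match classifier

# Shape check: (3 * bs, num_query, hidden) -> (3 * bs, 2) after averaging the
# per-query logits, matching the `logits = vl_output.mean(dim=1)` line above.
vl_embeddings = torch.randn(6, 32, hidden_size)
logits = itm_head(vl_embeddings).mean(dim=1)
print(logits.shape)  # torch.Size([6, 2])
```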
The core code of BertLayer is implemented as follows: