#### 1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)

> - Purpose: align the image representation with the text representation so as to maximize their mutual information
>
> - Self-attention mask strategy: Uni-modal Self-attention Mask
>
> - Queries and text may only attend to tokens of their own modality (query-to-query, text-to-text)
>
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
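To make the figure concrete, here is a minimal sketch of the block-diagonal mask it describes, assuming 32 query tokens and 16 text tokens (illustrative sizes, not from the source). In the LAVIS code this effect actually comes from encoding the queries and the text in two separate forward passes during ITC, rather than from an explicit mask:

```python
import torch

# Illustrative sizes: 32 query tokens, 16 text tokens (assumptions, not from the source)
num_query, text_len = 32, 16
total = num_query + text_len

# mask[i, j] == 1 means position i may attend to position j
mask = torch.zeros(total, total, dtype=torch.long)
mask[:num_query, :num_query] = 1  # queries attend only to queries
mask[num_query:, num_query:] = 1  # text tokens attend only to text tokens
# The result is block-diagonal: no query-text attention during ITC
```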
#### 2. Image-Text Matching (ITM Loss, a binary classification task)

> - Purpose: learn whether an image-text pair matches, aligning the image representation with the text representation at a fine-grained level
>
> - Self-attention mask strategy: Bi-directional Self-attention Mask
>
> - Queries and text can all attend to every token
>
> ![Bi-directional Self-attention Mask](庖丁解牛BLIP2/7.png)
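For contrast with the ITC sketch above, the bi-directional case blocks nothing: over the concatenated [queries; text] sequence the mask is simply all ones (padding aside). In the ITM code below this shows up as `attention_mask_all`, a plain padding mask with no modality blocking:

```python
import torch

# Same illustrative sizes as the ITC sketch above
num_query, text_len = 32, 16
total = num_query + text_len

# Every position may attend to every other position
mask = torch.ones(total, total, dtype=torch.long)
```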
Each output query embedding is fed into a binary classifier to produce a logit; the average over all query logits is the final matching score:
![matching score](庖丁解牛BLIP2/8.png)

In the code below, negatives are mined in-batch: for each text, a negative image is sampled with probability proportional to its contrastive similarity (and likewise a negative text for each image), so the classifier is trained on the pairs that are hardest to distinguish.
```python
###============== Image-text Matching ===================###
# Excerpt from Blip2Qformer.forward; torch and torch.nn.functional as F are in scope
text_input_ids_world = text_tokens.input_ids
text_attention_mask_world = text_tokens.attention_mask
image_embeds_world = image_embeds

with torch.no_grad():
    # bs (batch size), diag_indices = [0, 1, 2, ..., bs-1]
    diag_indices = torch.arange(bs, device=sim_t2i.device)
    # Set the diagonal of the similarity matrices to a large negative value
    # so that a matched image-text pair is never sampled as a negative:
    # positions (0,0), (1,1), ..., (bs-1,bs-1) are set to -10000
    sim_t2i[diag_indices, diag_indices] = -10000
    sim_i2t[diag_indices, diag_indices] = -10000

    weights_t2i = F.softmax(sim_t2i, dim=1)
    weights_i2t = F.softmax(sim_i2t, dim=1)

# Select one negative image for each text
image_embeds_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# Select one negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_ids_neg.append(text_input_ids_world[neg_idx])
    text_atts_neg.append(text_attention_mask_world[neg_idx])
text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

# Build the input text list for the pair batches [positive, negative 1, negative 2],
# shape (3*bs, seq_len)
text_ids_all = torch.cat(
    [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
)
text_atts_all = torch.cat(
    [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
    dim=0,
)

# Build the query tokens for the pair batches [positive, negative 1, negative 2],
# shape (3*bs, query_tokens_seq_len, hidden_size)
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
    image.device
)
# Build the padding mask over queries and text, shape (3*bs, seq_len)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

# Build the input images for the pair batches [positive, negative 1, negative 2],
# shape (3*bs, seq_len, hidden_size)
image_embeds_all = torch.cat(
    [image_embeds, image_embeds_neg, image_embeds], dim=0
)
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
    image.device
)

# 1. Embed the input text and concatenate it with the query tokens along the
#    seq_len dimension, shape (3*bs, text_seq_len + query_tokens_seq_len, hidden_size)
# 2. Run the concatenated sequence through the Q-Former, cross-attending to the
#    image embeddings, to obtain the encoded output
output_itm = self.Qformer.bert(
    text_ids_all,
    query_embeds=query_tokens_itm,
    attention_mask=attention_mask_all,
    encoder_hidden_states=image_embeds_all,
    encoder_attention_mask=image_atts_all,
    return_dict=True,
)

# Take the query-token part of the (3*bs, text_seq_len + query_tokens_seq_len, hidden_size)
# output, shape (3*bs, query_tokens_seq_len, hidden_size)
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
# Map every query position into the 2-dimensional matching space,
# shape (3*bs, query_tokens_seq_len, 2)
vl_output = self.itm_head(vl_embeddings)
# Average over the query positions as the final matching score, shape (3*bs, 2)
logits = vl_output.mean(dim=1)

# Build the matching labels [positive batch = 1, negative batch 1 = 0, negative batch 2 = 0],
# shape (3*bs,)
itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
    dim=0,
).to(image.device)
# Compute the cross-entropy loss
loss_itm = F.cross_entropy(logits, itm_labels)
```
When text and query tokens are fed into the BertModel together, BertEmbeddings concatenates the query token embeddings and the text embeddings (queries first) along the seq_len dimension.
```python
class BertEmbeddings(nn.Module):
    ...
    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        # Compute the text sequence length
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        # Generate position ids automatically if none are provided
        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        # Add word embeddings and position embeddings; prepend query_embeds if given
        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
```
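A quick shape walk-through of what this concatenation produces, under assumed illustrative sizes (batch 12, i.e. 3*bs with bs = 4; 32 query tokens; 16 text tokens; hidden size 768):

```python
import torch

# Illustrative sizes only; none of these numbers come from the source
query_embeds = torch.randn(12, 32, 768)  # query_tokens_itm
text_embeds = torch.randn(12, 16, 768)   # word + position embeddings of the text
embeddings = torch.cat((query_embeds, text_embeds), dim=1)
assert embeddings.shape == (12, 32 + 16, 768)
# Queries come first, which is why the ITM code above slices
# last_hidden_state[:, :query_tokens_itm.size(1), :] for the matching head
```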
The core implementation of BertLayer is as follows: