Commit 5f9e727 ("updates")

1 parent 8c81606
1 file changed: src/MMLLM/庖丁解牛BLIP2.md (23 additions, 18 deletions)
@@ -210,65 +210,65 @@ each image_feat in image_feats is scored against text_feat to produce a similarity score,
 
 ```python
 ###============== Image-text Matching ===================###
-
 text_input_ids_world = text_tokens.input_ids
 text_attention_mask_world = text_tokens.attention_mask
 image_embeds_world = image_embeds
 
 with torch.no_grad():
-    if "image_id" in samples.keys():
-        mask = torch.eq(image_ids, image_ids.t())
-        sim_t2i.masked_fill_(mask, -10000)
-        sim_i2t.masked_fill_(mask, -10000)
-    else:
-        # On a single GPU, sim_t2i[b, b] is the sample's own pair; mask it out to prevent cheating
-        diag_indices = torch.arange(bs, device=sim_t2i.device)
-        sim_t2i[diag_indices, diag_indices] = -10000
-        sim_i2t[diag_indices, diag_indices] = -10000
+    # bs (batch size), diag_indices = [0, 1, 2, ..., bs-1]
+    diag_indices = torch.arange(bs, device=sim_t2i.device)
+    # Set the diagonal of the similarity matrices to a very large negative value so the
+    # model never picks a matched image-text pair as a negative sample:
+    # positions (0,0), (1,1), ..., (bs-1,bs-1) are set to -10000
+    sim_t2i[diag_indices, diag_indices] = -10000
+    sim_i2t[diag_indices, diag_indices] = -10000
 
     weights_t2i = F.softmax(sim_t2i, dim=1)
     weights_i2t = F.softmax(sim_i2t, dim=1)
 
 # select a negative image for each text
 image_embeds_neg = []
 for b in range(bs):
     neg_idx = torch.multinomial(weights_t2i[b], 1).item()
     image_embeds_neg.append(image_embeds_world[neg_idx])
 image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
 
 # select a negative text for each image
 text_ids_neg = []
 text_atts_neg = []
 for b in range(bs):
     neg_idx = torch.multinomial(weights_i2t[b], 1).item()
     text_ids_neg.append(text_input_ids_world[neg_idx])
     text_atts_neg.append(text_attention_mask_world[neg_idx])
-
 text_ids_neg = torch.stack(text_ids_neg, dim=0)
 text_atts_neg = torch.stack(text_atts_neg, dim=0)
 
-# Build the ITM input: positives + negatives
+# Build the input text list for [positive batch, negative batch 1, negative batch 2]; shape (3*bs, seq_len)
 text_ids_all = torch.cat(
     [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
 )
 text_atts_all = torch.cat(
     [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
     dim=0,
 )
 
+# Build the query tokens for [positive batch, negative batch 1, negative batch 2]; shape (3*bs, seq_len, hidden_size)
 query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
 query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
     image.device
 )
+# Build the padding mask over query tokens and text; shape (3*bs, seq_len)
 attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
 
+# Build the input image list for [positive batch, negative batch 1, negative batch 2]; shape (3*bs, seq_len, hidden_size)
 image_embeds_all = torch.cat(
     [image_embeds, image_embeds_neg, image_embeds], dim=0
 )
 image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
     image.device
 )
 
+# 1. The input text is embedded and concatenated with the query tokens along the seq_len
+#    dimension; shape (3*bs, text_seq_len + query_tokens_seq_len, hidden_size)
+# 2. The concatenated query tokens and text are fused with the image embeddings via
+#    cross attention, and the encoder returns the fused output
 output_itm = self.Qformer.bert(
     text_ids_all,
     query_embeds=query_tokens_itm,
@@ -278,14 +278,19 @@ each image_feat in image_feats is scored against text_feat to produce a similarity score,
     return_dict=True,
 )
 
+# Take the query-token part of the (3*bs, text_seq_len + query_tokens_seq_len, hidden_size)
+# output; shape (3*bs, query_tokens_seq_len, hidden_size)
 vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
+# Project each query-token position into the 2-dim matching space; shape (3*bs, query_tokens_seq_len, 2)
 vl_output = self.itm_head(vl_embeddings)
+# Average over the positions to get the final matching score; shape (3*bs, 2)
 logits = vl_output.mean(dim=1)
 
+# Build the matching labels: [positive batch = 1, negative batch 1 = 0, negative batch 2 = 0]; shape (3*bs,)
 itm_labels = torch.cat(
     [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
     dim=0,
 ).to(image.device)
+# Compute the cross-entropy loss
 loss_itm = F.cross_entropy(logits, itm_labels)
 ```
 
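To make the shapes in this hunk easier to follow, here is a minimal, self-contained sketch of the same ITM hard-negative mining and scoring flow. It uses random tensors in place of the ViT/Q-Former outputs, a plain `nn.Linear` standing in for `self.itm_head`, and a `fake_qformer` stub instead of `self.Qformer.bert`; the variable names mirror the diff, but this is not the actual BLIP-2 implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes standing in for the real BLIP-2 dimensions.
bs, seq_len, num_query_tokens, hidden_size = 4, 8, 6, 16

# Stand-ins for the encoder outputs used in the diff.
image_embeds = torch.randn(bs, num_query_tokens, hidden_size)   # per-image features
text_input_ids = torch.randint(0, 1000, (bs, seq_len))          # tokenized captions
text_attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
sim_t2i = torch.randn(bs, bs)                                    # text-to-image similarities
sim_i2t = torch.randn(bs, bs)                                    # image-to-text similarities

with torch.no_grad():
    # Mask the diagonal so a matched pair can never be drawn as its own negative.
    diag = torch.arange(bs)
    sim_t2i[diag, diag] = -10000
    sim_i2t[diag, diag] = -10000
    weights_t2i = F.softmax(sim_t2i, dim=1)
    weights_i2t = F.softmax(sim_i2t, dim=1)

# Sample one hard-negative image per text and one hard-negative text per image
# (the diff does this with a per-sample loop; a batched multinomial has the same effect).
neg_img_idx = torch.multinomial(weights_t2i, 1).squeeze(1)   # (bs,)
neg_txt_idx = torch.multinomial(weights_i2t, 1).squeeze(1)   # (bs,)
image_embeds_neg = image_embeds[neg_img_idx]
text_ids_neg = text_input_ids[neg_txt_idx]
text_atts_neg = text_attention_mask[neg_txt_idx]

# Assemble the 3*bs ITM batch: (pos img, pos txt), (neg img, pos txt), (pos img, neg txt).
text_ids_all = torch.cat([text_input_ids, text_input_ids, text_ids_neg], dim=0)
text_atts_all = torch.cat([text_attention_mask, text_attention_mask, text_atts_neg], dim=0)
image_embeds_all = torch.cat([image_embeds, image_embeds_neg, image_embeds], dim=0)

# Dummy fusion step standing in for self.Qformer.bert: it just returns one hidden
# state per query token, shaped (3*bs, num_query_tokens, hidden_size).
def fake_qformer(text_ids, text_atts, img_embeds):
    return torch.randn(text_ids.size(0), num_query_tokens, hidden_size)

vl_embeddings = fake_qformer(text_ids_all, text_atts_all, image_embeds_all)

# Two-way match/no-match head, analogous to the diff's self.itm_head.
itm_head = nn.Linear(hidden_size, 2)
logits = itm_head(vl_embeddings).mean(dim=1)                 # (3*bs, 2)

# First bs pairs are matches, the remaining 2*bs are negatives.
itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)]
)
loss_itm = F.cross_entropy(logits, itm_labels)
print(logits.shape, loss_itm.item())
```

With `bs = 4` the printed logits shape is `torch.Size([12, 2])`, matching the `(3*bs, 2)` annotation in the diff's comments.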