
Commit 8c81606

updates

1 parent b43d930 commit 8c81606

1 file changed: +87 −6 lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 87 additions & 6 deletions
@@ -142,11 +142,11 @@ class Blip2Qformer(Blip2Base):

#### 1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)

> - Goal: align the Image representation and the Text representation so as to maximize their mutual information
>
> - Self-attention mask strategy: Uni-modal Self-attention Mask
>
> - Queries and Text can only attend to their own tokens (Query with Query, Text with Text); see the code sketch below
>
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
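
To make the uni-modal masking concrete, here is a minimal illustrative sketch of such a mask over a concatenated [queries | text] sequence. This is not the actual Q-Former code (which builds its masks inside the BERT attention layers); the function name and the `num_query` / `num_text` sizes are placeholders.

```python
import torch

def build_unimodal_mask(num_query: int, num_text: int) -> torch.Tensor:
    """Illustrative uni-modal self-attention mask: 1 = may attend, 0 = blocked.

    Rows are attending positions, columns are attended positions, over the
    concatenated sequence [query tokens | text tokens].
    """
    total = num_query + num_text
    mask = torch.zeros(total, total, dtype=torch.long)
    mask[:num_query, :num_query] = 1  # queries attend only to queries
    mask[num_query:, num_query:] = 1  # text attends only to text
    return mask

# e.g. 4 query tokens and 3 text tokens -> a block-diagonal mask
print(build_unimodal_mask(4, 3))
```
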
@@ -195,18 +195,99 @@ each image_feat in image_feats computes a similarity score with text_feat,

#### 2. Image-Text Matching (ITM Loss, a binary classification task)

> - Goal: learn whether an image-text pair matches, so as to align the Image representation and the Text representation at a fine-grained level
>
> - Self-attention mask strategy: Bi-directional Self-attention Mask
>
> - Queries and Text can all attend to all tokens; see the code sketch below
>
> ![Bi-directional Self-attention Mask](庖丁解牛BLIP2/7.png)
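
For contrast with the ITC mask above, here is a minimal illustrative sketch of the bi-directional case: nothing is blocked, so the mask over the concatenated [queries | text] sequence is simply all ones. In the ITM code below this corresponds to the concatenated padding mask `attention_mask_all`, with self-attention otherwise left fully bidirectional; the function name and sizes here are placeholders.

```python
import torch

def build_bidirectional_mask(num_query: int, num_text: int) -> torch.Tensor:
    """Illustrative bi-directional self-attention mask: every position may
    attend to every other position in [query tokens | text tokens]."""
    total = num_query + num_text
    return torch.ones(total, total, dtype=torch.long)

# e.g. 4 query tokens and 3 text tokens -> an all-ones mask
print(build_bidirectional_mask(4, 3))
```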


Each output query embedding is fed into a binary classifier to obtain a logit; the average of all the logits serves as the final matching score:

![matching score](庖丁解牛BLIP2/8.png)

```python
###============== Image-text Matching ===================###

text_input_ids_world = text_tokens.input_ids
text_attention_mask_world = text_tokens.attention_mask
image_embeds_world = image_embeds

with torch.no_grad():
    if "image_id" in samples.keys():
        mask = torch.eq(image_ids, image_ids.t())
        sim_t2i.masked_fill_(mask, -10000)
        sim_i2t.masked_fill_(mask, -10000)
    else:
        # On a single GPU, sim_t2i[b, b] is the sample's own (positive) pair;
        # mask it out so it cannot be drawn as a "hard negative"
        diag_indices = torch.arange(bs, device=sim_t2i.device)
        sim_t2i[diag_indices, diag_indices] = -10000
        sim_i2t[diag_indices, diag_indices] = -10000

    weights_t2i = F.softmax(sim_t2i, dim=1)
    weights_i2t = F.softmax(sim_i2t, dim=1)

# select a negative image for each text (higher similarity -> more likely sampled)
image_embeds_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_t2i[b], 1).item()
    image_embeds_neg.append(image_embeds_world[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

# select a negative text for each image
text_ids_neg = []
text_atts_neg = []
for b in range(bs):
    neg_idx = torch.multinomial(weights_i2t[b], 1).item()
    text_ids_neg.append(text_input_ids_world[neg_idx])
    text_atts_neg.append(text_attention_mask_world[neg_idx])

text_ids_neg = torch.stack(text_ids_neg, dim=0)
text_atts_neg = torch.stack(text_atts_neg, dim=0)

# Build the ITM inputs: positive pairs + negative pairs
text_ids_all = torch.cat(
    [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
)
text_atts_all = torch.cat(
    [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
    dim=0,
)

query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
    image.device
)
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

# image order matches the text order above:
# (pos text, pos image), (pos text, neg image), (neg text, pos image)
image_embeds_all = torch.cat(
    [image_embeds, image_embeds_neg, image_embeds], dim=0
)
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
    image.device
)

output_itm = self.Qformer.bert(
    text_ids_all,
    query_embeds=query_tokens_itm,
    attention_mask=attention_mask_all,
    encoder_hidden_states=image_embeds_all,
    encoder_attention_mask=image_atts_all,
    return_dict=True,
)

# each output query embedding -> binary classifier -> one logit; average the logits
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
vl_output = self.itm_head(vl_embeddings)
logits = vl_output.mean(dim=1)

# first bs pairs are matched (label 1), the remaining 2 * bs are mismatched (label 0)
itm_labels = torch.cat(
    [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
    dim=0,
).to(image.device)
loss_itm = F.cross_entropy(logits, itm_labels)
```

The core implementation of BertLayer is as follows:
