Skip to content

Commit 8d6e42d

Browse files
authored
Merge pull request #8 from BinaryOracle/master
Master
2 parents 074ca45 + 252298d commit 8d6e42d

File tree

1 file changed

+134
-6
lines changed

1 file changed

+134
-6
lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 134 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ class Blip2Qformer(Blip2Base):
142142

143143
#### 1、Image-Text Contrastive Learning (ITC Loss, CLIP-like)
144144

145-
> 目的: Image representation 与 Text representation,以最大化互信息
145+
> - 目的: Image representation 与 Text representation,以最大化互信息
146146
>
147-
> 自注意力掩码策略: Uni-modal Self-attention Mask(单模态自注意力)
147+
> - 自注意力掩码策略: Uni-modal Self-attention Mask(单模态自注意力)
148148
>
149-
> Queries 和 Text 仅能和自己的 tokens 做 attention(Query和Query、Text和Text)
149+
> - Queries 和 Text 仅能和自己的 tokens 做 attention(Query和Query、Text和Text)
150150
>
151151
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
152152
@@ -195,18 +195,146 @@ image_feats 中每个 image_feat 与 text_feat 计算一个 similarity score ,
195195
```
196196
#### 2、Image-Text Matching (ITM Loss,二分类task)
197197

198-
> 目的:通过学习image-text pair是否match,以细粒度对齐 Image representation 与 Text representation
198+
> - 目的:通过学习image-text pair是否match,以细粒度对齐 Image representation 与 Text representation
199199
>
200-
> 自注意力掩码策略: Bi-directional Self-attention Mask(双向自注意力)
200+
> - 自注意力掩码策略: Bi-directional Self-attention Mask(双向自注意力)
201201
>
202-
> Queries 和Text都能和所有的tokens 做attention
202+
> - Queries 和Text都能和所有的tokens 做attention
203+
>
203204
> ![Bi-directional Self-attention Mask](庖丁解牛BLIP2/7.png)
204205
205206

206207
每个output query embedding送到二分类器中,得到一个logit;所有logits的平均作为最终的matching score:
207208

208209
![matching score](庖丁解牛BLIP2/8.png)
209210

211+
```python
212+
###============== Image-text Matching ===================###
213+
text_input_ids_world = text_tokens.input_ids
214+
text_attention_mask_world = text_tokens.attention_mask
215+
image_embeds_world = image_embeds
216+
217+
with torch.no_grad():
218+
# bs (batch size) , diag_indices = [0,1,2,...,bs-1]
219+
diag_indices = torch.arange(bs, device=sim_t2i.device)
220+
# 把相似度矩阵对角线元素置为负无穷大,以避免模型将匹配图文对挑选为负样本
221+
# (0,0) , (1,1) ... (bs-1,bs-1) 位置处设置为 -10000
222+
sim_t2i[diag_indices, diag_indices] = -10000
223+
sim_i2t[diag_indices, diag_indices] = -10000
224+
225+
weights_t2i = F.softmax(sim_t2i, dim=1)
226+
weights_i2t = F.softmax(sim_i2t, dim=1)
227+
228+
# 为每个文本选择一个负样本图像
229+
image_embeds_neg = []
230+
for b in range(bs):
231+
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
232+
image_embeds_neg.append(image_embeds_world[neg_idx])
233+
image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
234+
235+
# 为每个图像选择一个负样本文本
236+
text_ids_neg = []
237+
text_atts_neg = []
238+
for b in range(bs):
239+
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
240+
text_ids_neg.append(text_input_ids_world[neg_idx])
241+
text_atts_neg.append(text_attention_mask_world[neg_idx])
242+
text_ids_neg = torch.stack(text_ids_neg, dim=0)
243+
text_atts_neg = torch.stack(text_atts_neg, dim=0)
244+
245+
# 构建输入文本列表: [正样本batch,负样本batch1,负样本batch2] ,维度为 (3*bs,seq_len)
246+
text_ids_all = torch.cat(
247+
[text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
248+
)
249+
text_atts_all = torch.cat(
250+
[text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
251+
dim=0,
252+
)
253+
254+
# 构建query tokens列表: [正样本batch,负样本batch1,负样本batch2] ,维度为 (3*bs,seq_len,hidden_size)
255+
query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
256+
query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
257+
image.device
258+
)
259+
# 构建query和text的padding mask ,维度为 (3*bs,seq_len)
260+
attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
261+
262+
# 构建输入图像列表: [正样本batch,负样本batch1,负样本batch2] ,维度为 (3*bs,seq_len,hidden_size)
263+
image_embeds_all = torch.cat(
264+
[image_embeds, image_embeds_neg, image_embeds], dim=0
265+
)
266+
image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
267+
image.device
268+
)
269+
270+
# 1. 将输入文本转换为嵌入列表后和query tokens 在seq_len维度上拼接起来,维度为 (3*bs,text_seq_len + query_tokens_seq_len,hidden_size)
271+
# 2. 将文本和query tokens拼接得到的结果和图像嵌入进行cross attention计算,编码后得到输出的结果
272+
output_itm = self.Qformer.bert(
273+
text_ids_all,
274+
query_embeds=query_tokens_itm,
275+
attention_mask=attention_mask_all,
276+
encoder_hidden_states=image_embeds_all,
277+
encoder_attention_mask=image_atts_all,
278+
return_dict=True,
279+
)
280+
281+
# 取 (3*bs,text_seq_len + query_tokens_seq_len,hidden_size) 中 query tokens部分的结果,维度为 (3*bs,query_tokens_seq_len,hidden_size)
282+
vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
283+
# 把query tokens部分的每个位置都映射到2维匹配空间,维度为 (3*bs,query_tokens_seq_len,2)
284+
vl_output = self.itm_head(vl_embeddings)
285+
# 取每个位置的平均作为最终的匹配得分,维度为 (3*bs,2)
286+
logits = vl_output.mean(dim=1)
287+
288+
# 构建匹配标签: [正样本batch=1,负样本batch1=0,负样本batch2=0] ,维度为 (3*bs)
289+
itm_labels = torch.cat(
290+
[torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
291+
dim=0,
292+
).to(image.device)
293+
# 计算交叉熵损失
294+
loss_itm = F.cross_entropy(logits, itm_labels)
295+
```
296+
当文本和query tokens同时输入BertModel时,BertEmbeddings会将text embeddings和query tokens的embeddings在seq_len维度上拼接起来。
297+
298+
```python
299+
class BertEmbeddings(nn.Module):
300+
...
301+
def forward(
302+
self,
303+
input_ids=None,
304+
position_ids=None,
305+
query_embeds=None,
306+
past_key_values_length=0,
307+
):
308+
# 计算序列长度
309+
if input_ids is not None:
310+
seq_length = input_ids.size()[1]
311+
else:
312+
seq_length = 0
313+
314+
# 如果未提供位置id,则自动生成
315+
if position_ids is None:
316+
position_ids = self.position_ids[
317+
:, past_key_values_length : seq_length + past_key_values_length
318+
].clone()
319+
320+
# 词嵌入与位置嵌入相加,若有query_embeds则拼接
321+
if input_ids is not None:
322+
embeddings = self.word_embeddings(input_ids)
323+
if self.position_embedding_type == "absolute":
324+
position_embeddings = self.position_embeddings(position_ids)
325+
embeddings = embeddings + position_embeddings
326+
327+
if query_embeds is not None:
328+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
329+
else:
330+
embeddings = query_embeds
331+
332+
embeddings = self.LayerNorm(embeddings)
333+
embeddings = self.dropout(embeddings)
334+
return embeddings
335+
```
336+
337+
210338

211339
BertLayer 核心代码实现如下:
212340

0 commit comments

Comments
 (0)