
Commit fd35d81

updates

1 parent 763a69f commit fd35d81

File tree

3 files changed: +249 -26 lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 249 additions & 26 deletions

Q-Former consists of two transformer submodules, and its input has three parts:

3. Input Text

**Stage 1** is pre-trained on image-text pairs; the goal is to train the Q-Former well, **so that the queries learn how to better use the text to extract image information**.

A helpful way to understand the Q-Former is to think of it as a self-attention module:

- Q: learned queries
- K: input text
- V: image embeddings from the Image Encoder

The core of Blip2Qformer is implemented as follows:

1. Use the query tokens to extract from the image embeddings the visual information most relevant to the text.
2. Encode the input text, then take the first CLS token as the input text representation.

```python
class Blip2Qformer(Blip2Base):
    ...

    def forward(self, samples):
        image = samples["image"]  # (B, C, H, W)
        text = samples["text_input"]  # (B, seq_len)
        # the frozen ViT encodes the image into (B, seq_len, hidden_size)
        image_embeds = self.ln_vision(self.visual_encoder(image))
        # padding mask marking which image tokens are valid, (B, seq_len)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
        # initialize the query tokens, (B, num_query_tokens, hidden_size)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        # the query tokens extract from the image embeddings the visual
        # information most relevant to the text
        # query_output: (B, num_query_tokens, hidden_size)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        # encode the input text, (B, seq_len, hidden_size)
        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,  # padding mask
            return_dict=True,
        )
        # take the first CLS token as the input text representation, (B, hidden_size)
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )
        ...
```

To train the Q-Former well, Stage 1 uses three training objectives, described below.

1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)

> Goal: align the image representation with the text representation so as to maximize their mutual information.
> Self-attention masking strategy: uni-modal self-attention mask.
> Queries and text can only attend to their own tokens (query-to-query, text-to-text).
> ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
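
The masking in the figure can be sketched directly. This is only a conceptual illustration (the helper `unimodal_attention_mask` and the toy sizes are assumptions, not BLIP-2/LAVIS code); in the `forward` shown above, the queries and the text are in fact encoded in two separate Q-Former passes, which achieves the same effect.

```python
# Conceptual sketch of the uni-modal self-attention mask (illustrative, not LAVIS code):
# queries may only attend to queries, text may only attend to text.
import torch

def unimodal_attention_mask(num_queries: int, text_len: int) -> torch.Tensor:
    total = num_queries + text_len
    mask = torch.zeros(total, total, dtype=torch.bool)
    mask[:num_queries, :num_queries] = True  # query -> query
    mask[num_queries:, num_queries:] = True  # text  -> text
    return mask

print(unimodal_attention_mask(num_queries=2, text_len=3).int())
```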

Each image_feat in image_feats computes a similarity score with text_feat, and the maximum over the query tokens is taken as the similarity of the image-text pair:

![similarity score](庖丁解牛BLIP2/5.png)
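
A minimal sketch of this max-over-queries similarity and the resulting ITC loss follows. It is a simplification, not the exact LAVIS implementation: the helper name `itc_loss` is an assumption, the temperature is a fixed constant here (BLIP-2 learns it), and the distributed gathering of negatives is omitted; the cross-entropy with label smoothing mirrors the `loss_itc` term in the original code.

```python
# Simplified ITC sketch (illustrative); image_feats: (B, num_query, D), text_feat: (B, D),
# both L2-normalized as in the forward() above.
import torch
import torch.nn.functional as F

def itc_loss(image_feats: torch.Tensor, text_feat: torch.Tensor, temp: float = 0.07) -> torch.Tensor:
    # similarity of every query token with every text: (B_img, num_query, B_txt)
    sim_per_query = torch.einsum("iqd,td->iqt", image_feats, text_feat)
    # max over the query tokens gives the image-to-text similarity matrix: (B_img, B_txt)
    sim_i2t = sim_per_query.max(dim=1).values / temp
    sim_t2i = sim_i2t.t()  # text-to-image view of the same scores
    # matched image-text pairs sit on the diagonal
    targets = torch.arange(image_feats.size(0), device=image_feats.device)
    return (
        F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
    ) / 2
```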

The core of BertLayer is implemented as follows:

```python
class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        ...
        self.attention = BertAttention(config)
        # every cross_attention_freq layers, add a cross-attention block
        if (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,  # query tokens
        attention_mask=None,  # query token padding mask
        head_mask=None,
        encoder_hidden_states=None,  # image tokens
        encoder_attention_mask=None,  # image padding mask
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        # self-attention over the query tokens (and the text tokens when they are appended after the queries)
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:  # number of query tokens = 32
            query_attention_output = attention_output[:, :query_length, :]
            # the query tokens, after self-attention, become the context states
            if self.has_cross_attention:
                # q: context states; k and v come from the image encoder
                # q(b,32,h) * k(b,seq_len,h).t = (b,32,seq_len)
                # score(b,32,seq_len) * v(b,seq_len,h) = (b,32,h)
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights
            # chunked feed-forward processing
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
```

```python
class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )  # (context_layer, attention_probs)
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attentions if we output them
        return outputs
```

```python
class BertSelfAttention(nn.Module):
    ...
    def transpose_for_scores(self, x):
        # reshape the tensor for multi-head attention:
        # (B, seq_len, all_head_size) -> (B, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # cross-attention if encoder hidden states (image embeddings) are given
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:  # key and value come from the image
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # compute the attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # apply the attention (padding) mask
            attention_scores = attention_scores + attention_mask

        # softmax normalization yields the attention probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # dropout to prevent overfitting
        attention_probs_dropped = self.dropout(attention_probs)

        # compute the context representation
        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs
```
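
As a quick check of the shapes noted in the BertLayer comments, here is a toy example of the query-to-image cross-attention path; the batch size, head count, and head size are made-up example values (e.g., a ViT producing 257 image tokens), not BLIP-2 constants.

```python
# Toy shape check for the query -> image cross-attention path (illustrative sizes).
import torch

B, num_heads, num_queries, num_image_tokens, head_dim = 2, 12, 32, 257, 64
q = torch.randn(B, num_heads, num_queries, head_dim)       # query tokens
k = torch.randn(B, num_heads, num_image_tokens, head_dim)  # image keys
v = torch.randn(B, num_heads, num_image_tokens, head_dim)  # image values

scores = q @ k.transpose(-1, -2) / head_dim**0.5  # (B, heads, 32, 257)
probs = scores.softmax(dim=-1)
context = probs @ v                                # (B, heads, 32, 64)
print(scores.shape, context.shape)
```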

src/MMLLM/庖丁解牛BLIP2/4.png (66.6 KB)

src/MMLLM/庖丁解牛BLIP2/5.png (168 KB)

0 commit comments