@@ -75,41 +75,48 @@ The Q-Former consists of two transformer modules, and its input has three parts:

3. Input Text

- Stage 1 uses image-text pairs for pre-training. The goal is to train the Q-Former well, **so that the queries learn how to better extract image information in combination with the text**.
+ **Stage 1** uses image-text pairs for pre-training. The goal is to train the Q-Former well, **so that the queries learn how to better extract image information in combination with the text**.

A fairly intuitive way to understand the Q-Former is to think of it as analogous to a self-attention module:

- Q: learned queries
- K: input text
- V: image embeddings from the Image Encoder
+ The core of Blip2Qformer is implemented as follows:
+
+ 1. Use the query tokens to extract, from the image embeddings, the visual information most relevant to the text.
+ 2. Encode the input text, then use the first [CLS] token as the input text representation.
+
```python
class Blip2Qformer(Blip2Base):
    ...
-
+
    def forward(self, samples):
-         image = samples["image"]
-         text = samples["text_input"]
-         # (2,257,1408) --> attn: (2,257), (batch_size, seq_len, hidden_size)
+         image = samples["image"]  # (B, C, H, W)
+         text = samples["text_input"]  # (B, seq_len)
+         # The frozen ViT encodes the image into (B, seq_len, hidden_size)
        image_embeds = self.ln_vision(self.visual_encoder(image))
+         # Padding mask marking which image tokens are valid, (B, seq_len)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
-         # (1,32,768) --> (2,32,768) --> shared memory
+         # Expand the learned query tokens to the batch, (B, num_query_tokens, hidden_size)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
-
+         # The query tokens extract the text-relevant visual information from the image embeddings
+         # query_output: (B, num_query_tokens, hidden_size)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )
-         # sequence_output of the BertEncoder
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )
-
+
+         # Encode the input text, (B, seq_len, hidden_size)
        text_tokens = self.tokenizer(
            text,
            padding="max_length",
@@ -122,29 +129,245 @@ class Blip2Qformer(Blip2Base):
            attention_mask=text_tokens.attention_mask,  # padding mask
            return_dict=True,
        )
-         text_feat = F.normalize(  # take the [CLS] token?
+         # Take the first [CLS] token as the input text representation, (B, hidden_size)
+         text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )
+         ...
+ ```
+ To train the Q-Former well, Stage 1 uses three training objectives, described below:
+
+ 1. Image-Text Contrastive Learning (ITC Loss, CLIP-like)
+
+ > Goal: align the image representation with the text representation so as to maximize their mutual information.
+ > Self-attention masking strategy: uni-modal self-attention mask.
+ > Queries and text may only attend to tokens of their own modality (query-to-query, text-to-text), as sketched in the example below.
+ > ![Uni-modal Self-attention Mask](庖丁解牛BLIP2/4.png)
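
A minimal illustrative sketch of such a uni-modal mask (toy sizes, standalone PyTorch, not the actual BLIP-2 code): query tokens may only attend to query tokens, text tokens only to text tokens, and the mask is applied additively before the softmax, just as in the BertSelfAttention code further below.

```python
import torch

# Toy sizes: 4 query tokens followed by 6 text tokens in one sequence (illustrative only)
num_query, num_text = 4, 6
total = num_query + num_text

# Block-diagonal boolean mask: True = attention allowed
allowed = torch.zeros(total, total, dtype=torch.bool)
allowed[:num_query, :num_query] = True   # queries attend only to queries
allowed[num_query:, num_query:] = True   # text attends only to text

# Additive form used by BERT-style attention: 0 where allowed, -inf where blocked
additive_mask = torch.zeros(total, total).masked_fill(~allowed, float("-inf"))

scores = torch.randn(total, total)                     # stand-in attention scores
probs = torch.softmax(scores + additive_mask, dim=-1)  # query rows put ~0 mass on text tokens
```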
+
+ Each image_feat in image_feats is scored against text_feat, and the maximum of these similarity scores is taken as the similarity of the image-text pair (a sketch of this computation follows below):
+
+ ![similarity score](庖丁解牛BLIP2/5.png)
+
+
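A rough sketch of that max-over-queries similarity and the resulting in-batch contrastive loss (illustrative shapes, temperature scaling omitted; not the exact BLIP-2 source):

```python
import torch
import torch.nn.functional as F

B, num_query, dim = 2, 32, 256  # illustrative sizes
image_feats = F.normalize(torch.randn(B, num_query, dim), dim=-1)  # one feature per query token
text_feat = F.normalize(torch.randn(B, dim), dim=-1)               # [CLS] text feature

# Similarity of every query of every image against every text: (B_img, B_txt, num_query)
sim_q2t = torch.einsum("iqd,td->itq", image_feats, text_feat)

# Keep the best-matching query per image-text pair: (B_img, B_txt)
sim_i2t, _ = sim_q2t.max(dim=-1)
sim_t2i = sim_i2t.t()

# In-batch contrastive loss: the i-th image matches the i-th text
targets = torch.arange(B)
loss_itc = (
    F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
    + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
```
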
-         ###============== Image-text Contrastive ===================###
-         ...
-         loss_itc = (
-             F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
-             + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
-         ) / 2
-         ###============== Image-text Matching ===================###
-         ...
-         loss_itm = F.cross_entropy(logits, itm_labels)
-         ##================= Image Captioning ========================##
+
+
+ The core of BertLayer is implemented as follows:
+
+ ```python
+ class BertLayer(nn.Module):
+     def __init__(self, config, layer_num):
+         super().__init__()
        ...
-         loss_lm = lm_output.loss
+         self.attention = BertAttention(config)
+         # Every cross_attention_freq layers, insert a cross-attention block
+         if (
+             self.config.add_cross_attention
+             and layer_num % self.config.cross_attention_freq == 0
+         ):
+             self.crossattention = BertAttention(
+                 config, is_cross_attention=self.config.add_cross_attention
+             )
+             self.has_cross_attention = True
+         else:
+             self.has_cross_attention = False
+         self.intermediate = BertIntermediate(config)
+         self.output = BertOutput(config)
+
+         self.intermediate_query = BertIntermediate(config)
+         self.output_query = BertOutput(config)
+
+     def forward(
+         self,
+         hidden_states,  # query tokens (optionally followed by text tokens)
+         attention_mask=None,  # padding / self-attention mask
+         head_mask=None,
+         encoder_hidden_states=None,  # image tokens
+         encoder_attention_mask=None,  # image padding mask
+         past_key_value=None,
+         output_attentions=False,
+         query_length=0,
+     ):
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = (
+             past_key_value[:2] if past_key_value is not None else None
+         )
+         # Self-attention over the query tokens (plus text tokens, if any are appended)
+         self_attention_outputs = self.attention(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             output_attentions=output_attentions,
+             past_key_value=self_attn_past_key_value,
+         )
+         attention_output = self_attention_outputs[0]
+         outputs = self_attention_outputs[1:-1]
+
+         present_key_value = self_attention_outputs[-1]
+
+         if query_length > 0:  # number of query tokens (32 in BLIP-2)
+             # hidden states of the query tokens after self-attention (context states)
+             query_attention_output = attention_output[:, :query_length, :]
+             if self.has_cross_attention:
+                 # q: the query context states; k and v come from the image encoder
+                 # q(b,32,h) @ k(b,seq_len,h).T = (b,32,seq_len)
+                 # score(b,32,seq_len) @ v(b,seq_len,h) = (b,32,h)
+                 cross_attention_outputs = self.crossattention(
+                     query_attention_output,
+                     attention_mask,
+                     head_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     output_attentions=output_attentions,
+                 )
+                 query_attention_output = cross_attention_outputs[0]
+                 outputs = (
+                     outputs + cross_attention_outputs[1:-1]
+                 )  # add cross attentions if we output attention weights
+             # chunked feed-forward over the sequence dimension
+             layer_output = apply_chunking_to_forward(
+                 self.feed_forward_chunk_query,
+                 self.chunk_size_feed_forward,
+                 self.seq_len_dim,
+                 query_attention_output,
+             )
+             if attention_output.shape[1] > query_length:
+                 layer_output_text = apply_chunking_to_forward(
+                     self.feed_forward_chunk,
+                     self.chunk_size_feed_forward,
+                     self.seq_len_dim,
+                     attention_output[:, query_length:, :],
+                 )
+                 layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+         else:
+             layer_output = apply_chunking_to_forward(
+                 self.feed_forward_chunk,
+                 self.chunk_size_feed_forward,
+                 self.seq_len_dim,
+                 attention_output,
+             )
+         outputs = (layer_output,) + outputs
+
+         outputs = outputs + (present_key_value,)
+
+         return outputs
+
+     def feed_forward_chunk(self, attention_output):
+         intermediate_output = self.intermediate(attention_output)
+         layer_output = self.output(intermediate_output, attention_output)
+         return layer_output
+
+     def feed_forward_chunk_query(self, attention_output):
+         intermediate_output = self.intermediate_query(attention_output)
+         layer_output = self.output_query(intermediate_output, attention_output)
+         return layer_output
+ ```
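
The feed-forward above goes through Hugging Face's apply_chunking_to_forward, which splits the input along the sequence dimension and applies the feed-forward chunk by chunk to reduce peak memory; the result is identical to a single full-width call. A small illustrative sketch (toy module and sizes; the import path may differ across transformers versions):

```python
import torch
from transformers.pytorch_utils import apply_chunking_to_forward  # transformers.modeling_utils in older versions

ffn = torch.nn.Linear(16, 16)   # stand-in for BertIntermediate + BertOutput
x = torch.randn(2, 32, 16)      # (batch, seq_len, hidden)

# chunk_size=8 along dim=1: the FFN sees four slices of 8 tokens each,
# and the outputs are concatenated back along the sequence dimension.
chunked = apply_chunking_to_forward(ffn, 8, 1, x)
full = ffn(x)
print(torch.allclose(chunked, full, atol=1e-6))  # True
```
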
-         return BlipOutput(
-             loss=loss_itc + loss_itm + loss_lm,
-             loss_itc=loss_itc,
-             loss_itm=loss_itm,
-             loss_lm=loss_lm,
+ ```python
+ class BertAttention(nn.Module):
+     def __init__(self, config, is_cross_attention=False):
+         super().__init__()
+         self.self = BertSelfAttention(config, is_cross_attention)
+         self.output = BertSelfOutput(config)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         past_key_value=None,
+         output_attentions=False,
+     ):
+         self_outputs = self.self(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             past_key_value,
+             output_attentions,
+         )  # (context_layer, attention_probs)
+         attention_output = self.output(self_outputs[0], hidden_states)
+
+         outputs = (attention_output,) + self_outputs[
+             1:
+         ]  # add attentions if we output them
+         return outputs
+ ```
+
+ ```python
+ class BertSelfAttention(nn.Module):
+     ...
+     def transpose_for_scores(self, x):
+         # Reshape for multi-head attention: (B, seq, hidden) -> (B, heads, seq, head_size)
+         new_x_shape = x.size()[:-1] + (
+             self.num_attention_heads,
+             self.attention_head_size,
+         )
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         past_key_value=None,
+         output_attentions=False,
+     ):
+
+         # Cross-attention when encoder hidden states (image features) are provided
+         is_cross_attention = encoder_hidden_states is not None
+
+         if is_cross_attention:  # key and value come from the image
+             key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+             attention_mask = encoder_attention_mask
+         elif past_key_value is not None:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+         else:
+             key_layer = self.transpose_for_scores(self.key(hidden_states))
+             value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+         mixed_query_layer = self.query(hidden_states)
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         past_key_value = (key_layer, value_layer)
+
+         # Scaled dot-product attention scores
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         if attention_mask is not None:
+             # Apply the additive attention mask
+             attention_scores = attention_scores + attention_mask
+
+         # Softmax normalization gives the attention probabilities
+         attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+         # Dropout on the attention probabilities
+         attention_probs_dropped = self.dropout(attention_probs)
+
+         # Weighted sum over the values gives the context representation
+         context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+
+         outputs = (
+             (context_layer, attention_probs) if output_attentions else (context_layer,)
         )
+
+         outputs = outputs + (past_key_value,)
+         return outputs
 ```
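
To tie the shape annotations in BertLayer (q(b,32,h), score(b,32,seq_len), output (b,32,h)) back to this attention math, here is a standalone single-head sketch with toy tensors (2 images, 32 query tokens, 257 image tokens, hidden size 768, all illustrative; projections omitted):

```python
import torch

b, num_query, img_len, h = 2, 32, 257, 768   # illustrative sizes
q = torch.randn(b, num_query, h)             # query-token context states
k = torch.randn(b, img_len, h)               # image embeddings as keys
v = torch.randn(b, img_len, h)               # image embeddings as values

# q(b,32,h) @ k(b,seq_len,h).T -> scores (b,32,seq_len)
scores = torch.matmul(q, k.transpose(-1, -2)) / h ** 0.5
probs = torch.softmax(scores, dim=-1)

# probs(b,32,seq_len) @ v(b,seq_len,h) -> context (b,32,h)
context = torch.matmul(probs, v)
print(scores.shape, context.shape)  # torch.Size([2, 32, 257]) torch.Size([2, 32, 768])
```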