@@ -75,4 +75,76 @@ The Q-Former consists of two transformer modules, and its input has three parts:
3. Input Text

-
+ Stage 1 is pre-trained on image-text pairs; the goal is to train the Q-Former well, **so that the queries learn how to extract image information more effectively in combination with the text**.
+
+ For the Q-Former, an intuitive way to understand it is to think of it as an attention module (a toy sketch follows the list):
+
+ - Q: learned queries
+ - K: input text
+ - V: image embeddings from the Image Encoder
+
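Before reading the full forward pass below, here is a minimal sketch of the mechanism it builds on. All shapes and the plain `nn.MultiheadAttention` layer are illustrative assumptions, not the actual Q-Former internals: a fixed set of learned query tokens cross-attends to the frozen image encoder's output, yielding a small number of query features per image. (In the real Q-Former the queries also interact with the text through shared self-attention layers; this sketch shows only the query-to-image cross-attention.)

```python
import torch
import torch.nn as nn

# Toy sketch only: shapes and this plain attention layer are assumptions for illustration.
B, Q, N, D = 2, 32, 257, 768   # batch, learned queries, image patch tokens, hidden size

query_tokens = nn.Parameter(torch.zeros(1, Q, D))    # learned queries, shared across samples
image_embeds = torch.randn(B, N, D)                   # frozen image-encoder output (assumed already projected to D)

cross_attn = nn.MultiheadAttention(embed_dim=D, num_heads=12, batch_first=True)
q = query_tokens.expand(B, -1, -1)                    # (B, Q, D), analogous to query_tokens.expand(...) below
out, _ = cross_attn(query=q, key=image_embeds, value=image_embeds)
print(out.shape)                                      # torch.Size([2, 32, 768]): 32 query features per image
```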
+ ```python
+ class Blip2Qformer(Blip2Base):
+     ...
+
+     def forward(self, samples):
+         image = samples["image"]
+         text = samples["text_input"]
+         # image_embeds: (2, 257, 1408); image_atts: (2, 257) -- (batch_size, seq_len, hidden_size)
+         image_embeds = self.ln_vision(self.visual_encoder(image))
+         image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+             image.device
+         )
+         # (1, 32, 768) --> (2, 32, 768); expand returns a view, so the memory is shared
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+         query_output = self.Qformer.bert(
+             query_embeds=query_tokens,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_atts,
+             use_cache=True,
+             return_dict=True,
+         )
+         # sequence_output of the BertEncoder
+         image_feats = F.normalize(
+             self.vision_proj(query_output.last_hidden_state), dim=-1
+         )
+
+         text_tokens = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.max_txt_len,
+             return_tensors="pt",
+         ).to(image.device)
+         text_output = self.Qformer.bert(
+             text_tokens.input_ids,
+             attention_mask=text_tokens.attention_mask,  # padding mask
+             return_dict=True,
+         )
+         text_feat = F.normalize(  # take the [CLS] token embedding
+             self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
+         )
+
+         ###============== Image-text Contrastive ===================###
+         ...
+         loss_itc = (
+             F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+             + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
+         ) / 2
+
+         ###============== Image-text Matching ===================###
+         ...
+         loss_itm = F.cross_entropy(logits, itm_labels)
+
+         ##================= Image Captioning ========================##
+         ...
+         loss_lm = lm_output.loss
+
+         return BlipOutput(
+             loss=loss_itc + loss_itm + loss_lm,
+             loss_itc=loss_itc,
+             loss_itm=loss_itm,
+             loss_lm=loss_lm,
+         )
+ ```
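The three loss terms above are elided with `...`. As a rough, simplified sketch of what the image-text contrastive (ITC) part computes (single GPU, no cross-rank all-gather, fixed temperature; the function and variable names here are assumptions for illustration, not the library's exact code): each image is represented by all 32 normalized query features, and the image-to-text similarity takes the maximum over the queries before the symmetric cross-entropy loss.

```python
import torch
import torch.nn.functional as F

def itc_loss(image_feats, text_feat, temp=0.07):
    """image_feats: (B, Q, D) normalized query features; text_feat: (B, D) normalized [CLS] features."""
    # similarity of every query of every image with every text: (B, B, Q)
    sim_q2t = torch.einsum("iqd,td->itq", image_feats, text_feat)
    sim_i2t = sim_q2t.max(-1).values / temp        # image-to-text: best-matching query, (B, B)
    sim_t2i = sim_i2t.t()                          # text-to-image (same matrix transposed in this single-GPU sketch)
    targets = torch.arange(image_feats.size(0))    # matching pairs sit on the diagonal
    return (
        F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
    ) / 2

# toy usage with illustrative shapes
loss = itc_loss(F.normalize(torch.randn(2, 32, 256), dim=-1),
                F.normalize(torch.randn(2, 256), dim=-1))
```

Roughly speaking, the image-text matching (ITM) term is a binary matched/unmatched classification over image-text pairs (with hard negatives sampled using the ITC similarities), and the language-modeling (LM) term is the caption-generation loss of the text side conditioned on the query features through the Q-Former's shared self-attention.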