BLIP-2 builds on the BLIP architecture: it combines an existing ViT and an existing LLM (both frozen) with a lightweight Q-Former module.
The BLIP-2 framework pre-trains a lightweight query Transformer (Q-Former) with a two-stage strategy to bridge the modality gap.

Stage 1: extract and fuse features from the different modalities.

Stage 2: convert the fused features into a format the LLM can understand.

![Two-Stage pipeline](庖丁解牛BLIP2/2.png)

Stage 1 bootstraps vision-language representation learning from the frozen Image Encoder.

Stage 2 bootstraps vision-to-language generative learning from the frozen LLM, enabling zero-shot image-to-text generation.
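The sketch below shows how the two stages fit together at the tensor level. It is only an illustration, not the LAVIS API: the Q-Former is reduced to a single cross-attention layer, the module names are made up, and the LLM hidden size of 4096 is an assumed example; the other shapes follow the comments in the `forward` pass further down.

```python
import torch
import torch.nn as nn

batch, num_patches, vit_dim = 2, 257, 1408          # frozen ViT output: (B, 257, 1408)
num_queries, qformer_dim, llm_dim = 32, 768, 4096   # llm_dim is an assumed example value

# Frozen image encoder output (stand-in for self.visual_encoder + ln_vision).
image_embeds = torch.randn(batch, num_patches, vit_dim)

# The learned queries are ordinary trainable parameters of the Q-Former.
learned_queries = nn.Parameter(torch.zeros(1, num_queries, qformer_dim))

# Stage 1: the queries cross-attend to the frozen image features; training with
# ITC / ITM / captioning losses teaches them to pull out text-relevant visual info.
cross_attn = nn.MultiheadAttention(
    qformer_dim, num_heads=12, kdim=vit_dim, vdim=vit_dim, batch_first=True
)
q = learned_queries.expand(batch, -1, -1)                    # (2, 32, 768)
query_output, _ = cross_attn(q, image_embeds, image_embeds)  # (2, 32, 768)

# Stage 2: a projection maps the query outputs into the frozen LLM's input
# embedding space, and the LLM generates text conditioned on this soft prefix.
proj_to_llm = nn.Linear(qformer_dim, llm_dim)
llm_prefix = proj_to_llm(query_output)                       # (2, 32, 4096)
print(llm_prefix.shape)
```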
### Stage 1: Representation Learning

![Stage 1: Representation Learning](庖丁解牛BLIP2/3.png)

The Q-Former consists of two transformer submodules, and its input has three parts:

1. Image embeddings extracted by the frozen Image Encoder
2. Learned Queries

> - The Queries are a set of learnable embeddings and serve as the input to the first transformer submodule; they can be regarded as part of the model parameters (see the sketch after this list).
> - At inference time, the Queries are used to extract, from the embeddings output by the image encoder, the visual information most relevant to the input text.

3. Input Text
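To make "Learned Queries" concrete, here is a minimal sketch of how such query tokens can be created and broadcast to a batch. The variable names are illustrative; the (1, 32, 768) shape is taken from the comments in the `forward` pass below, where `self.query_tokens.expand(...)` performs the same broadcast.

```python
import torch
import torch.nn as nn

hidden_size, num_query_tokens = 768, 32

# The queries are plain trainable parameters: a (1, 32, 768) tensor learned during
# pre-training, independent of any particular image or text.
query_tokens = nn.Parameter(torch.zeros(1, num_query_tokens, hidden_size))
nn.init.normal_(query_tokens, std=0.02)

# At run time they are broadcast to the batch size and fed into the Q-Former.
# expand() returns views that share the underlying storage (no copy is made).
batch_size = 2
batch_queries = query_tokens.expand(batch_size, -1, -1)  # (2, 32, 768)
print(batch_queries.shape)
```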
Stage 1 is pre-trained on image-text pairs. The goal is to train the Q-Former well, **so that the Queries learn how to extract image information better by combining it with the text**.

An intuitive way to think about the Q-Former is to treat it like a self-attention module:

- Q: learned queries
- K: input text
- V: image embeddings from the Image Encoder
```python
class Blip2Qformer(Blip2Base):
    ...

    def forward(self, samples):
        image = samples["image"]
        text = samples["text_input"]
        # (2, 257, 1408) --> attn mask: (2, 257); shapes are (batch_size, seq_len, hidden_size)
        image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
        # (1, 32, 768) --> (2, 32, 768); the expanded views share memory
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        # Queries attend to the frozen image features via cross-attention inside the Q-Former
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )
        # sequence_output of the BertEncoder
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,  # padding mask
            return_dict=True,
        )
        text_feat = F.normalize(  # take the [CLS] token
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        ###============== Image-text Contrastive ===================###
        ...
        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        ###============== Image-text Matching ===================###
        ...
        loss_itm = F.cross_entropy(logits, itm_labels)

        ##================= Image Captioning ========================##
        ...
        loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
        )
```
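The three elided blocks compute the Stage-1 objectives: image-text contrastive (ITC), image-text matching (ITM), and image-grounded text generation. As an illustration of the first one, here is a self-contained, simplified ITC computation. It is a sketch, not the LAVIS code: it uses a fixed temperature and in-batch negatives only (the real implementation learns the temperature and gathers negatives across GPUs), but it keeps BLIP-2's choice of scoring each image-text pair by the maximum similarity over the 32 query outputs.

```python
import torch
import torch.nn.functional as F


def itc_loss(image_feats, text_feat, temperature=0.07):
    """Simplified image-text contrastive loss (illustrative only).

    image_feats: (B, Q, D) normalized query outputs from the Q-Former
    text_feat:   (B, D)    normalized [CLS] text features
    """
    # Similarity between every query of every image and every text: (B, B, Q),
    # then keep the best-matching query per (image, text) pair.
    sim_q2t = torch.einsum("bqd,cd->bcq", image_feats, text_feat)
    sim_i2t = sim_q2t.max(dim=-1).values / temperature  # (B, B)
    sim_t2i = sim_i2t.t()

    # Matched pairs sit on the diagonal of the similarity matrix.
    targets = torch.arange(image_feats.size(0), device=image_feats.device)
    return (
        F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
    ) / 2


# Toy usage with random, pre-normalized features (batch 2, 32 queries, 256-dim).
image_feats = F.normalize(torch.randn(2, 32, 256), dim=-1)
text_feat = F.normalize(torch.randn(2, 256), dim=-1)
print(itc_loss(image_feats, text_feat))
```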