
Commit 763a69f

updates
1 parent 08815d8 commit 763a69f

1 file changed: +73 -1 lines changed


src/MMLLM/庖丁解牛BLIP2.md

@@ -75,4 +75,76 @@ The Q-Former consists of two transformer submodules, and its input has three parts:
3. Input Text
Stage 1 is pretrained on image-text pairs. The goal is to train the Q-Former **so that the queries learn how to better extract image information conditioned on the text**.
An intuitive way to think about the Q-Former is to treat it as a self-attention module (see the sketch after the list below):
- Q: learned queries
- K: input text
- V: image embeddings from the Image Encoder
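Taken literally, this analogy does not quite type-check (K and V must share a sequence length). In the actual forward pass below, the queries attend to the image as both K and V through cross-attention, while the text interacts with the queries through the shared self-attention layers. Here is a minimal single-head sketch of that queries-to-image cross-attention step, with shapes borrowed from the forward pass below; the `Linear` projections are illustrative stand-ins for the BERT cross-attention weights, not the LAVIS implementation.

```python
import torch

B, n_q, n_img = 2, 32, 257   # batch, learned queries, image patches (+ CLS)
d_q, d_img = 768, 1408       # Q-Former hidden size, ViT hidden size

queries = torch.randn(B, n_q, d_q)           # learned query tokens
image_embeds = torch.randn(B, n_img, d_img)  # frozen ViT output

# stand-ins for the cross-attention Q/K/V projections
w_q = torch.nn.Linear(d_q, d_q)
w_k = torch.nn.Linear(d_img, d_q)
w_v = torch.nn.Linear(d_img, d_q)

q = w_q(queries)       # (2, 32, 768)
k = w_k(image_embeds)  # (2, 257, 768)
v = w_v(image_embeds)  # (2, 257, 768)

attn = torch.softmax(q @ k.transpose(1, 2) / d_q**0.5, dim=-1)  # (2, 32, 257)
out = attn @ v  # (2, 32, 768): 32 query slots summarizing the image
```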
```python
class Blip2Qformer(Blip2Base):
    ...

    def forward(self, samples):
        image = samples["image"]
        text = samples["text_input"]

        # frozen ViT features: (2, 257, 1408) = (batch_size, seq_len, hidden_size)
        image_embeds = self.ln_vision(self.visual_encoder(image))
        # attention mask over the image tokens: (2, 257), all ones (no padding)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
        # learned queries: (1, 32, 768) -> (2, 32, 768); expand() shares memory
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        # the queries cross-attend to the image inside the BERT encoder
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )
        # sequence_output of the BertEncoder, projected and L2-normalized
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,  # padding mask
            return_dict=True,
        )
        text_feat = F.normalize(  # take the [CLS] token as the text feature
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        ###============== Image-text Contrastive ===================###
        ...
        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        ###============== Image-text Matching ===================###
        ...
        loss_itm = F.cross_entropy(logits, itm_labels)

        ##================= Image Captioning ========================##
        ...
        loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
        )
```
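The three pretraining objectives are elided with `...` above, and they stay elided here. Purely as a generic illustration of the in-batch contrastive (ITC) pattern those lines implement, here is a self-contained sketch; the fixed `temp = 0.07` is an assumption (the real model learns `self.temp`), and LAVIS additionally gathers features across GPUs and aggregates the 32 query features with a max on the image side.

```python
import torch
import torch.nn.functional as F

B, d = 2, 256  # batch size, projection dimension (assumed)
image_feat = F.normalize(torch.randn(B, d), dim=-1)  # one vector per image
text_feat = F.normalize(torch.randn(B, d), dim=-1)   # projected [CLS] per caption

temp = 0.07                                  # assumed constant temperature
sim_i2t = image_feat @ text_feat.t() / temp  # (B, B) similarity matrix
sim_t2i = sim_i2t.t()
targets = torch.arange(B)                    # matching pairs lie on the diagonal

loss_itc = (
    F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
    + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
) / 2
```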
