
Commit 93641aa

Merge pull request #4 from BinaryOracle/master
Master
2 parents c6318cd + 763a69f commit 93641aa

File tree

2 files changed: +94 −4 lines changed

src/MMLLM/庖丁解牛BLIP2.md

Lines changed: 94 additions & 4 deletions
@@ -51,10 +51,100 @@ BLIP-2 is based on the BLIP architecture, leveraging an existing ViT and LLM (both frozen) plus a
The BLIP-2 framework pre-trains a lightweight Querying Transformer (Q-Former) with a two-stage strategy to bridge the modality gap.

Stage 1: extract and fuse features from the different modalities.

Stage 2: convert the fused features into a format the LLM can understand.

![Two-Stage pipeline](庖丁解牛BLIP2/2.png)

Stage 1 bootstraps vision-language representation learning from the frozen Image Encoder.

Stage 2 bootstraps vision-to-language generative learning from the frozen LLM, enabling zero-shot image-to-text generation.
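To make "a format the LLM can understand" concrete, here is a minimal sketch of the Stage 2 hand-off at the shape level: the Q-Former's query outputs are mapped by a linear projection into the frozen LLM's embedding dimension and prepended to the text embeddings as soft visual prompts. All sizes below are illustrative assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumed): 32 queries, 768-dim Q-Former outputs,
# 16 text tokens, 4096-dim LLM token embeddings.
B = 2
query_output = torch.randn(B, 32, 768)   # stand-in for the Q-Former's query outputs
text_embeds = torch.randn(B, 16, 4096)   # stand-in for the frozen LLM's text embeddings

# A linear projection maps the query outputs into the LLM's embedding space ...
llm_proj = nn.Linear(768, 4096)
visual_prompts = llm_proj(query_output)  # (B, 32, 4096)

# ... and the projected queries are prepended to the text as soft visual prompts,
# so the frozen LLM conditions its generation on the image.
llm_inputs = torch.cat([visual_prompts, text_embeds], dim=1)  # (B, 48, 4096)
```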
### Stage 1: Representation Learning

![Stage 1: Representation Learning](庖丁解牛BLIP2/3.png)
The Q-Former consists of two transformer submodules, and its input has three parts:

1. Image embeddings extracted by the frozen Image Encoder
2. Learned Queries

> - The Queries are a set of learnable embeddings fed as input to the first transformer submodule; they can be regarded as part of the model's parameters.
> - At inference time, the Queries are used to extract, from the image encoder's output embeddings, the visual information most relevant to the input text.

3. Input Text
Stage 1 is pre-trained on image-text pairs. The goal is to train the Q-Former well, **so that the Queries learn to extract image information more effectively by taking the text into account.**

An intuitive way to understand the Q-Former is to treat it as a self-attention module (see the toy sketch after this list):

- Q: learned queries
- K: input text
- V: image embeddings from the Image Encoder
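Strictly speaking, inside the real Q-Former the image embeddings serve as both the keys and the values of the cross-attention layers, while the queries and the text interact through shared self-attention (attention masks decide who can see whom for each pretraining objective). The toy sketch below shows this shape-level picture; all sizes and the plain nn.MultiheadAttention modules are illustrative assumptions, not the LAVIS implementation.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): 32 learned queries, 768-dim Q-Former width,
# 1408-dim frozen ViT features, batch of 2, 12 text tokens.
B, hidden = 2, 768
query_tokens = nn.Parameter(torch.zeros(1, 32, hidden))  # learned queries (model parameters)
image_embeds = torch.randn(B, 257, 1408)                  # frozen Image Encoder output
text_embeds = torch.randn(B, 12, hidden)                  # embedded input text

# Cross-attention: the queries attend to the image embeddings,
# which serve as both keys and values.
to_kv = nn.Linear(1408, hidden)                           # bring image features to Q-Former width
cross_attn = nn.MultiheadAttention(hidden, num_heads=12, batch_first=True)
q = query_tokens.expand(B, -1, -1)                        # same learned queries for every sample
img_kv = to_kv(image_embeds)
q, _ = cross_attn(q, img_kv, img_kv)                      # queries now carry visual information

# Shared self-attention: queries and text tokens attend to each other, which is
# how the queries learn to pull out text-relevant visual features.
self_attn = nn.MultiheadAttention(hidden, num_heads=12, batch_first=True)
joint = torch.cat([q, text_embeds], dim=1)                # (B, 32 + 12, 768)
fused, _ = self_attn(joint, joint, joint)
query_out, text_out = fused[:, :32], fused[:, 32:]        # (B, 32, 768), (B, 12, 768)
```

With this picture in mind, here is the (abridged) Stage 1 forward pass of Blip2Qformer from LAVIS: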
```python
class Blip2Qformer(Blip2Base):
    ...

    def forward(self, samples):
        image = samples["image"]
        text = samples["text_input"]
        # image_embeds: (2, 257, 1408) = (batch_size, seq_len, hidden_size);
        # the attention mask image_atts is (2, 257)
        image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )
        # (1, 32, 768) --> (2, 32, 768); expand() shares the underlying memory
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        # the queries cross-attend to the frozen image features
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )
        # sequence_output of the BertEncoder, projected and L2-normalized
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,  # padding mask
            return_dict=True,
        )
        text_feat = F.normalize(  # take the [CLS] token
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        ###============== Image-text Contrastive ===================###
        ...
        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        ###============== Image-text Matching ===================###
        ...
        loss_itm = F.cross_entropy(logits, itm_labels)

        ##================= Image Captioning ========================##
        ...
        loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
        )
```
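The contrastive branch is elided in the snippet above. Below is a minimal, single-GPU sketch of how such an image-text contrastive loss can be computed from `image_feats` of shape (batch, num_queries, dim) and `text_feat` of shape (batch, dim): the query dimension is reduced with a max, in-batch indices serve as targets, and a fixed temperature stands in for the learned `self.temp`. It is an illustrative reconstruction under those assumptions, not the exact LAVIS code.

```python
import torch
import torch.nn.functional as F

def itc_loss_sketch(image_feats, text_feat, temp=0.07):
    """Illustrative in-batch image-text contrastive loss.

    image_feats: (B, num_queries, D), L2-normalized query features
    text_feat:   (B, D), L2-normalized [CLS] text features
    """
    # Similarity of every text against every query of every image: (B_img, B_txt, Q)
    sim = torch.einsum("iqd,td->itq", image_feats, text_feat)
    # Keep the best-matching query per image-text pair, then scale by temperature.
    sim_i2t = sim.max(dim=-1).values / temp          # (B_img, B_txt)
    sim_t2i = sim_i2t.t()                            # (B_txt, B_img)
    # In-batch targets: the i-th image matches the i-th text.
    targets = torch.arange(image_feats.size(0), device=image_feats.device)
    return (
        F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
        + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
    ) / 2

# Example usage with random, normalized features:
image_feats = F.normalize(torch.randn(2, 32, 256), dim=-1)
text_feat = F.normalize(torch.randn(2, 256), dim=-1)
print(itc_loss_sketch(image_feats, text_feat))
```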

src/MMLLM/庖丁解牛BLIP2/3.png

154 KB
