Commit 3bb194a

updates

1 parent 05adf92 commit 3bb194a

1 file changed: src/3DVL/Grounding_3D_Object_Affordance.md (112 additions & 1 deletion)
```diff
@@ -120,7 +120,8 @@ class LMAffordance3D(Blip2Base):
         # Map the LLM output back to the appropriate dimension through the adapter layer
         hidden_states = self.adapter_down(hidden_states)  # shape: [B, NS + NL, CS]

-        # Split out the instructional feature and the semantic feature
+        # Split out the semantic feature and the instructional feature
+        # (the visual-semantic feature and the language-instruction feature, respectively)
         semantic_feature, instructional_feature = torch.split(
             hidden_states,
             split_size_or_sections=spatial_feature.size(1),
```
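For reference, `torch.split` with an integer `split_size_or_sections` cuts the tensor into chunks of that size along `dim`, so the first chunk here covers the `spatial_feature.size(1)` visual tokens and the remainder is the instruction part. A minimal sketch with made-up sizes (`NS`, `NL`, `CS` are illustrative, not the model's real dimensions):

```python
import torch

B, NS, NL, CS = 2, 16, 8, 64  # illustrative sizes only
hidden_states = torch.randn(B, NS + NL, CS)

# Chunks of size NS along dim=1: a [B, NS, CS] part (visual-semantic tokens)
# and the [B, NL, CS] remainder (language-instruction tokens).
semantic_feature, instructional_feature = torch.split(hidden_states, NS, dim=1)
print(semantic_feature.shape)       # torch.Size([2, 16, 64])
print(instructional_feature.shape)  # torch.Size([2, 8, 64])
```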
```diff
@@ -323,8 +324,118 @@ def concat_input(self, input_embeds, input_atts, multi_embeds, image_atts=None):
     )
```

### Step 9: The decoder fuses all features to predict the affordance feature
```python
class Affordance_Decoder(nn.Module):
    def __init__(self, emb_dim, proj_dim):
        super().__init__()
        self.emb_dim = emb_dim
        self.proj_dim = proj_dim
        self.cross_atten = Cross_Attention(emb_dim=self.emb_dim, proj_dim=self.proj_dim)
        self.fusion = nn.Sequential(
            nn.Conv1d(2 * self.emb_dim, self.emb_dim, 1, 1),
            nn.BatchNorm1d(self.emb_dim),
            nn.ReLU()
        )

    def forward(self, query, key, value):
        '''
        query: [B, N_p + N_i, C] -> spatial_feature (query)
        key:   [B, N_l, C]       -> instructional_feature (key)
        value: [B, N_l, C]       -> semantic_feature (value)
        '''
        B, _, C = query.size()

        # Reshape key and value to [B, C, N_l]
        key = key.view(B, C, -1)      # [B, C, N_l]
        value = value.view(B, C, -1)  # [B, C, N_l]

        # Run cross attention to obtain two attention-weighted results
        Theta_1, Theta_2 = self.cross_atten(query, key.mT, value.mT)

        # Concatenate the two attention outputs along the channel dimension
        joint_context = torch.cat((Theta_1.mT, Theta_2.mT), dim=1)  # [B, 2C, N_p + N_i]

        # Fuse channel information with a 1x1 Conv1d
        affordance = self.fusion(joint_context)  # [B, C, N_p + N_i]

        # Rearrange the output back to [B, N_p + N_i, C]
        affordance = affordance.permute(0, 2, 1)  # [B, N_p + N_i, C]

        return affordance
```
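The `Cross_Attention` module used by the decoder runs two attention branches over the same query: one attends to the instructional feature, the other to the semantic feature.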
```python
class Cross_Attention(nn.Module):
    def __init__(self, emb_dim, proj_dim):
        """
        Multimodal cross-attention module that fuses the different kinds of
        semantic information coming from the language model to enrich the
        spatial feature representation.

        Args:
            emb_dim: input feature (embedding) dimension, e.g. the LLM hidden size (such as 4096)
            proj_dim: projection dimension used inside the attention to reduce computational cost
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.proj_dim = proj_dim

        # Projection layers that map the inputs into a lower-dimensional space
        # for the attention computation
        self.proj_q = nn.Linear(self.emb_dim, proj_dim)   # query projection
        self.proj_sk = nn.Linear(self.emb_dim, proj_dim)  # sub key projection
        self.proj_sv = nn.Linear(self.emb_dim, proj_dim)  # sub value projection
        self.proj_ek = nn.Linear(self.emb_dim, proj_dim)  # scene key projection
        self.proj_ev = nn.Linear(self.emb_dim, proj_dim)  # scene value projection

        # Scaling factor for normalizing the attention scores
        self.scale = self.proj_dim ** (-0.5)

        # Layer normalization to stabilize training
        self.layernorm = nn.LayerNorm(self.emb_dim)

    def forward(self, obj, sub, scene):
        """
        Perform cross attention to fuse information from different sources:
        - obj: spatial feature, used as the query;
        - sub: instructional feature, key and value of the first attention branch;
        - scene: semantic feature, key and value of the second attention branch.

        Args:
            obj: [B, N_p + HW, C] -> spatial_feature (query source)
            sub: [B, HW, C] -> instructional_feature (one key/value source)
            scene: [B, HW, C] -> semantic_feature (the other key/value source)

        Returns:
            I_1: attention-weighted output of the first branch
            I_2: attention-weighted output of the second branch
        """
        B, seq_length, C = obj.size()  # batch size, sequence length, channels

        # Project each input into the lower-dimensional attention space
        query = self.proj_q(obj)       # [B, N_q, proj_dim]
        s_key = self.proj_sk(sub)      # [B, N_i, proj_dim]
        s_value = self.proj_sv(sub)    # [B, N_i, proj_dim]

        e_key = self.proj_ek(scene)    # [B, N_e, proj_dim]
        e_value = self.proj_ev(scene)  # [B, N_e, proj_dim]

        # First cross attention: enhance the query with sub's key and value
        atten_I1 = torch.bmm(query, s_key.mT) * self.scale  # [B, N_q, N_i]
        atten_I1 = atten_I1.softmax(dim=-1)                 # softmax normalization
        I_1 = torch.bmm(atten_I1, s_value)                  # [B, N_q, proj_dim]

        # Second cross attention: enhance the query with scene's key and value
        atten_I2 = torch.bmm(query, e_key.mT) * self.scale  # [B, N_q, N_e]
        atten_I2 = atten_I2.softmax(dim=-1)
        I_2 = torch.bmm(atten_I2, e_value)                  # [B, N_q, proj_dim]

        # Residual connection + LayerNorm for stability
        # (note: the residual `obj + I_1` requires proj_dim == emb_dim)
        I_1 = self.layernorm(obj + I_1)  # [B, N_q, emb_dim]
        I_2 = self.layernorm(obj + I_2)  # [B, N_q, emb_dim]

        return I_1, I_2
```
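To sanity-check the shapes end to end, here is a minimal usage sketch with random tensors, assuming the two classes above (plus `torch` and `torch.nn as nn`) are in scope. The sizes are made up, and `emb_dim == proj_dim` is assumed so the residual connection inside `Cross_Attention` is well-formed:

```python
import torch

B, N_q, N_l = 2, 24, 8
emb_dim = proj_dim = 64  # must match for the residual `obj + I_1` to broadcast

decoder = Affordance_Decoder(emb_dim=emb_dim, proj_dim=proj_dim)

query = torch.randn(B, N_q, emb_dim)  # spatial feature       [B, N_p + N_i, C]
key = torch.randn(B, N_l, emb_dim)    # instructional feature [B, N_l, C]
value = torch.randn(B, N_l, emb_dim)  # semantic feature      [B, N_l, C]

with torch.no_grad():
    affordance = decoder(query, key, value)

print(affordance.shape)  # torch.Size([2, 24, 64])
```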