@@ -120,7 +120,8 @@ class LMAffordance3D(Blip2Base):
120
120
# 通过适配器层将 LLM 输出映射回合适维度
121
121
hidden_states = self .adapter_down(hidden_states) # shape: [B, NS + NL, CS]
122
122
123
- # 分割出 instructional feature 和 semantic feature
123
+ # 分割出 semantic feature 和 instructional feature
124
+ # 视觉语义特征 和 语言指令理解特征
124
125
semantic_feature, instructional_feature = torch.split(
125
126
hidden_states,
126
127
split_size_or_sections = spatial_feature.size(1 ),
@@ -323,8 +324,118 @@ def concat_input(self, input_embeds, input_atts, multi_embeds, image_atts=None):
323
324
)
324
325
```
325
326
327
+ ### Step 9: 解码器融合所有特征以预测可操作性特征
326
328
329
+ ``` python
330
+ class Affordance_Decoder (nn .Module ):
331
+ def __init__ (self , emb_dim , proj_dim ):
332
+ super ().__init__ ()
333
+ self .emb_dim = emb_dim
334
+ self .proj_dim = proj_dim
335
+ self .cross_atten = Cross_Attention(emb_dim = self .emb_dim, proj_dim = self .proj_dim)
336
+ self .fusion = nn.Sequential(
337
+ nn.Conv1d(2 * self .emb_dim, self .emb_dim, 1 , 1 ),
338
+ nn.BatchNorm1d(self .emb_dim),
339
+ nn.ReLU()
340
+ )
341
+
342
+ def forward (self , query , key , value ):
343
+ '''
344
+ query: [B, N_p + N_i, C] -> spatial_feature (query)
345
+ key: [B, N_l, C] -> instructional_feature (key)
346
+ value: [B, N_l, C] -> semantic_feature (value)
347
+ '''
348
+ B, _, C = query.size()
327
349
350
+ # 调整 key 和 value 的形状为 [B, C, N_l]
351
+ key = key.view(B, C, - 1 ) # [B, C, N_l]
352
+ value = value.view(B, C, - 1 ) # [B, C, N_l]
353
+
354
+ # 使用 cross attention 获取两个注意力加权结果
355
+ Theta_1, Theta_2 = self .cross_atten(query, key.mT, value.mT)
356
+
357
+ # 将两个注意力输出拼接在一起
358
+ joint_context = torch.cat((Theta_1.mT, Theta_2.mT), dim = 1 ) # [B, 2C, N_p + N_i]
359
+
360
+ # 使用 Conv1D 融合通道信息
361
+ affordance = self .fusion(joint_context) # [B, C, N_p + N_i]
362
+
363
+ # 调整输出格式为 [B, N_p + N_i, C]
364
+ affordance = affordance.permute(0 , 2 , 1 ) # [B, N_p + N_i, C]
365
+
366
+ return affordance
367
+ ```
368
+ ``` python
369
+ class Cross_Attention (nn .Module ):
370
+ def __init__ (self , emb_dim , proj_dim ):
371
+ """
372
+ 多模态交叉注意力模块(Cross-Attention Module),
373
+ 用于融合来自语言模型的不同语义信息,增强空间特征表达。
374
+
375
+ Args:
376
+ emb_dim: 输入特征维度(embedding dimension),例如 LLM 的 hidden size(如 4096)
377
+ proj_dim: 投影维度,用于降低计算复杂度,在 attention 中使用
378
+ """
379
+ super ().__init__ ()
380
+ self .emb_dim = emb_dim
381
+ self .proj_dim = proj_dim
382
+
383
+ # 定义投影层,将输入映射到低维空间以进行 attention 计算
384
+ self .proj_q = nn.Linear(self .emb_dim, proj_dim) # query 投影
385
+ self .proj_sk = nn.Linear(self .emb_dim, proj_dim) # sub key 投影
386
+ self .proj_sv = nn.Linear(self .emb_dim, proj_dim) # sub value 投影
387
+ self .proj_ek = nn.Linear(self .emb_dim, proj_dim) # scene key 投影
388
+ self .proj_ev = nn.Linear(self .emb_dim, proj_dim) # scene value 投影
389
+
390
+ # 缩放因子,用于 attention 分数归一化
391
+ self .scale = self .proj_dim ** (- 0.5 )
392
+
393
+ # 层归一化(LayerNorm),用于稳定训练过程
394
+ self .layernorm = nn.LayerNorm(self .emb_dim)
395
+
396
+ def forward (self , obj , sub , scene ):
397
+ """
398
+ 执行交叉注意力机制,融合不同来源的信息:
399
+ - obj: 空间特征(spatial feature),作为 query;
400
+ - sub: 指令理解特征(instructional feature),作为第一个 attention 的 key 和 value;
401
+ - scene: 视觉语义特征(semantic feature),作为第二个 attention 的 key 和 value;
402
+
403
+ Args:
404
+ obj: [B, N_p + HW, C] → spatial_feature(query 来源)
405
+ sub: [B, HW, C] → instructional_feature(key/value 来源之一)
406
+ scene: [B, HW, C] → semantic_feature(key/value 来源之二)
407
+
408
+ Returns:
409
+ I_1: 经过 attention 加权后的输出(第一分支)
410
+ I_2: 经过 attention 加权后的输出(第二分支)
411
+ """
412
+
413
+ B, seq_length, C = obj.size() # 获取 batch size 和通道维度
414
+
415
+ # 将输入分别投影到低维空间,便于后续 attention 计算
416
+ query = self .proj_q(obj) # [B, N_q, proj_dim]
417
+ s_key = self .proj_sk(sub) # [B, N_i, proj_dim]
418
+ s_value = self .proj_sv(sub) # [B, N_i, proj_dim]
419
+
420
+ e_key = self .proj_ek(scene) # [B, N_e, proj_dim]
421
+ e_value = self .proj_ev(scene) # [B, N_e, proj_dim]
422
+
423
+ # 第一个 cross attention:使用 sub 的 key 和 value 增强 query
424
+ atten_I1 = torch.bmm(query, s_key.mT) * self .scale # [B, N_q, N_i]
425
+ atten_I1 = atten_I1.softmax(dim = - 1 ) # softmax 归一化
426
+ I_1 = torch.bmm(atten_I1, s_value) # [B, N_q, proj_dim]
427
+
428
+ # 第二个 cross attention:使用 scene 的 key 和 value 增强 query
429
+ atten_I2 = torch.bmm(query, e_key.mT) * self .scale # [B, N_q, N_e]
430
+ atten_I2 = atten_I2.softmax(dim = - 1 )
431
+ I_2 = torch.bmm(atten_I2, e_value) # [B, N_q, proj_dim]
432
+
433
+ # 使用残差连接 + LayerNorm 增强稳定性
434
+ I_1 = self .layernorm(obj + I_1) # [B, N_q, emb_dim]
435
+ I_2 = self .layernorm(obj + I_2) # [B, N_q, emb_dim]
436
+
437
+ return I_1, I_2
438
+ ```
328
439
329
440
330
441
0 commit comments