@@ -441,44 +441,81 @@ class Cross_Attention(nn.Module):
441
441
``` python
442
442
class Head(nn.Module):
    """Prediction head that maps fused features to a per-point affordance score.

    The head propagates the aligned point tokens back up through three
    ``PointNetFeaturePropagation`` stages, modulates the result with a
    globally pooled affordance vector, and feeds it to a small MLP ending
    in a sigmoid, so every output value lies in ``[0, 1]``.
    """

    def __init__(self, additional_channel: int, emb_dim: int, N_p: int, N_raw: int) -> None:
        """Build the propagation stages, the pooling layer and the output MLP.

        Args:
            additional_channel: Number of extra per-point input channels; it is
                added to the input width of the last propagation stage.
            emb_dim: Embedding width of the incoming token features.
            N_p: Number of point-cloud tokens at the front of the token
                sequence (the original comments give 64 as an example).
            N_raw: Number of points in the raw point cloud (the original
                comments give 2048 as an example).
        """
        super().__init__()

        self.emb_dim = emb_dim
        self.N_p = N_p
        self.N_raw = N_raw

        # Feature-propagation (upsampling) stages, applied fp3 -> fp2 -> fp1.
        # NOTE(review): every stage's mlp ends in 512 while out_head and the
        # pooled affordance vector use emb_dim, so this presumably requires
        # emb_dim == 512 -- confirm against PointNetFeaturePropagation.
        self.fp3 = PointNetFeaturePropagation(in_channel=512 + self.emb_dim, mlp=[768, 512])
        self.fp2 = PointNetFeaturePropagation(in_channel=832, mlp=[768, 512])
        self.fp1 = PointNetFeaturePropagation(in_channel=518 + additional_channel, mlp=[512, 512])

        # Averages over the last axis, leaving a single value per channel.
        self.pool = nn.AdaptiveAvgPool1d(1)

        # NOTE(review): out_head is fed a [B, N_raw, C] tensor, so
        # BatchNorm1d(N_raw) treats the point axis as its channel axis.
        # Kept as-is to stay compatible with existing checkpoints.
        self.out_head = nn.Sequential(
            nn.Linear(self.emb_dim, self.emb_dim // 8),
            nn.BatchNorm1d(self.N_raw),
            nn.ReLU(),
            nn.Linear(self.emb_dim // 8, 1),
            nn.Sigmoid(),
        )

    def forward(self, multi_feature: torch.Tensor, affordance_feature: torch.Tensor, encoder_p) -> torch.Tensor:
        """Compute the per-point affordance scores.

        Args:
            multi_feature: Token features, ``[B, N_p + N_i, C]`` per the
                original annotations; only the first ``N_p`` tokens are used.
            affordance_feature: Affordance features with the same layout;
                only the first ``N_p`` tokens are used.
            encoder_p: Sequence of four hierarchy levels ``(p_0, p_1, p_2, p_3)``.
                Each level is indexed with ``[0]`` and ``[1]``; presumably
                coordinates and features -- confirm against the encoder.

        Returns:
            Sigmoid scores, ``[B, N_raw, 1]`` per the original annotations.
        """
        p_0, p_1, p_2, p_3 = encoder_p

        # Take the leading N_p point tokens by slicing. The previous
        # ``a, _ = torch.split(x, N_p, dim=1)`` only unpacked cleanly when the
        # remainder produced exactly one more chunk (0 < N_i <= N_p) and raised
        # ValueError otherwise; the slice yields the same tensor in that case
        # and works for any number of trailing tokens.
        P_align = multi_feature[:, :self.N_p]
        F_pa = affordance_feature[:, :self.N_p]

        # Propagate features stage by stage (shape comments are the original
        # author's annotations).
        up_sample = self.fp3(p_2[0], p_3[0], p_2[1], P_align.mT)  # [B, emb_dim, npoint_sa2]
        up_sample = self.fp2(p_1[0], p_2[0], p_1[1], up_sample)  # [B, emb_dim, npoint_sa1]
        up_sample = self.fp1(p_0[0], p_1[0], torch.cat([p_0[0], p_0[1]], 1), up_sample)  # [B, emb_dim, N_raw]

        # Pool the affordance tokens into one vector per sample.
        F_pa_pool = self.pool(F_pa.mT)  # [B, emb_dim, 1]

        # Channel-wise modulation: scale every point's features by the pooled vector.
        affordance = up_sample * F_pa_pool.expand(-1, -1, self.N_raw)  # [B, emb_dim, N_raw]

        out = self.out_head(affordance.mT)  # [B, N_raw, 1]

        return out
484
521
```
0 commit comments