@@ -310,7 +310,6 @@ class PointRefer(nn.Module):
310
310
_3daffordance = torch.einsum(' blc,bcn->bln' , t_feat, up_sample)
311
311
_3daffordance = _3daffordance.sum(1 )/ (t_mask.float().sum(1 ).unsqueeze(- 1 ))
312
312
_3daffordance = torch.sigmoid(_3daffordance)
313
- # logits = self.cls_head(p_3[1].mean(-1))
314
313
return _3daffordance.squeeze(- 1 )
315
314
```
316
315
> 论文中所给的模型架构图中的Encoder layer指的是PointNet++中提供的PointNetSetAbstractionMsg多尺度分组点集特征抽取类
@@ -578,6 +577,8 @@ class TransformerDecoderLayer(nn.Module):
578
577
query_pos : Optional[Tensor] = None ):
579
578
# 1. Affordance Query = 问题嵌入(Question Embedding)X + 可学习的位置编码(Learnable Position Embeddings)
580
579
# 这里tgt就是Roberta编码得到的文本特征嵌入向量
580
+ # 使用 X 作为初始输入,确保每个 query 都带有原始语言上下文;
581
+ # 如果只用 learnable embeddings,模型将完全依赖随机初始化的参数去“猜”语言含义,效率极低;
581
582
q = k = self .with_pos_embed(tgt, query_pos)
582
583
# 2. 自注意力机制: 让每个 query 不仅理解自己的语义,还能感知其他 query 的信息,从而形成更完整的语言上下文理解。
583
584
tgt2 = self .self_attn(q, k, value = tgt, attn_mask = tgt_mask,
@@ -599,6 +600,24 @@ class TransformerDecoderLayer(nn.Module):
599
600
tgt = tgt + self .dropout3(tgt2)
600
601
tgt = self .norm3(tgt) # (b,l,c)
601
602
return tgt
603
+
604
+ class PointRefer (nn .Module ):
605
+
606
+ def forward (self , text , xyz ):
607
+ ...
608
+ # 3. Referred Point Decoding过程
609
+ # 3.1 利用一组问题条件化的 affordance queries 通过 Transformer 解码器与点云特征交互 ,生成一组动态卷积核(dynamic kernels)
610
+ t_feat = self .decoder(t_feat, up_sample.transpose(- 2 , - 1 ), tgt_key_padding_mask = t_mask, query_pos = self .pos1d)
611
+ # 3.2 对无效 token(padding)做掩码操作,防止其影响后续计算。
612
+ t_feat *= t_mask.unsqueeze(- 1 ).float()
613
+ # 3.3 执行 动态卷积(Dynamic Convolution) 操作,用增强后的语言查询去“扫描”点云特征图 (b,l,n)
614
+ _3daffordance = torch.einsum(' blc,bcn->bln' , t_feat, up_sample)
615
+ # 3.4 对affordance query的响应图进行平均池化,融合所有 affordance query 的得分结果。 (b,n)
616
+ _3daffordance = _3daffordance.sum(1 )/ (t_mask.float().sum(1 ).unsqueeze(- 1 ))
617
+ # 3.5 将响应值映射到 [0, 1] 区间,表示每个点属于目标功能区域的概率。 (b,n)
618
+ _3daffordance = torch.sigmoid(_3daffordance)
619
+ return _3daffordance # (b,n)
602
620
```
603
621
604
622
623
+
0 commit comments