@@ -281,28 +281,32 @@ class PointRefer(nn.Module):
'''
B, C, N = xyz.size()
- # Encode the text with RoBERTa and the point cloud with PointNet++
+ # 1. Encoding stage
+ # 1.1 Language encoding: encode the text with RoBERTa
t_feat, t_mask = self.forward_text(list(text), xyz.device)  # [batch, q_len, d_model]
+ # 1.2 Backbone encoding: encode the point cloud with PointNet++
F_p_wise = self.point_encoder(xyz)
"""
Decoding
"""
- p_0, p_1, p_2, p_3 = F_p_wise  # point coordinates and features of each local region
- p_3[1] = self.gpb(t_feat, p_3[1].transpose(-2, -1)).transpose(-2, -1)  # fuse each region's features with the text information
- up_sample = self.fp3(p_2[0], p_3[0], p_2[1], p_3[1])  # [B, emb_dim, npoint_sa2] feature propagation
+ # 1.3 Point coordinates and point features produced by each PointNet++ set-abstraction level
+ p_0, p_1, p_2, p_3 = F_p_wise
+
+ # 2. Backbone decoding stage
+ # 2.1 Fuse the features of every point in the point set with the text features
+ p_3[1] = self.gpb(t_feat, p_3[1].transpose(-2, -1)).transpose(-2, -1)
+ # 2.2 PointNet++ feature propagation: while upsampling, the rebuilt point features of the denser level absorb both the high-level region-abstraction features and the text features
+ up_sample = self.fp3(p_2[0], p_3[0], p_2[1], p_3[1])  # [B, emb_dim, npoint_sa2]
up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
up_sample = self.fp2(p_1[0], p_2[0], p_1[1], up_sample)  # [B, emb_dim, npoint_sa1]
- up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
+ up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
+ # 2.3 End of feature propagation: step by step, rebuild back to the original number of points, 128 -> 256 -> 512 -> 1024 -> 2048
up_sample = self.fp1(p_0[0], p_1[0], torch.cat([p_0[0], p_0[1]], 1), up_sample)  # [B, emb_dim, N]
-
- # t_feat = t_feat.sum(1)/(t_mask.float().sum(1).unsqueeze(-1))
- # t_feat = t_feat.unsqueeze(1).repeat(1, self.n_groups,1)
- # t_feat += self.pos1d
-
- # print(t_feat.shape, up_sample.shape)
+
+ # 3. Referred point decoding stage
t_feat = self.decoder(t_feat, up_sample.transpose(-2, -1), tgt_key_padding_mask=t_mask, query_pos=self.pos1d)  # b,l,c
t_feat *= t_mask.unsqueeze(-1).float()
_3daffordance = torch.einsum('blc,bcn->bln', t_feat, up_sample)
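For context on the last line of the hunk: t_feat coming out of self.decoder holds one feature per text token with shape [B, L, C] (the "b,l,c" in the comment), up_sample is the propagated per-point feature map with shape [B, C, N], and torch.einsum('blc,bcn->bln', ...) is simply a batched dot product between every token and every point, giving a [B, L, N] affordance score map. A minimal sketch with dummy tensors; the concrete sizes below are assumptions for illustration, not values taken from the repository:

import torch

B, L, C, N = 2, 12, 512, 2048            # batch, text length, embedding dim, number of points (assumed)
t_feat = torch.randn(B, L, C)            # stands in for the decoder output
up_sample = torch.randn(B, C, N)         # stands in for the propagated point features

# One dot product per (token, point) pair.
scores = torch.einsum('blc,bcn->bln', t_feat, up_sample)       # [B, L, N]

# The same contraction written as a plain batched matmul.
assert torch.allclose(scores, torch.bmm(t_feat, up_sample), atol=1e-5)
print(scores.shape)                      # torch.Size([2, 12, 2048])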
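The fp3/fp2/fp1 calls match the argument pattern fp(xyz_dense, xyz_sparse, feats_dense, feats_sparse) used by common PointNet++ feature-propagation modules: features living on the sparser level are interpolated back onto the denser level by inverse-distance weighting of each dense point's three nearest sparse neighbours, then concatenated with the dense level's own features and passed through a shared MLP. The sketch below shows only the interpolation step; the shapes and the helper name interpolate_features are assumptions for illustration, not the repository's actual implementation.

import torch

def interpolate_features(xyz_dense, xyz_sparse, feats_sparse, k=3, eps=1e-8):
    # xyz_dense:    [B, N, 3] coordinates of the denser level
    # xyz_sparse:   [B, S, 3] coordinates of the sparser level
    # feats_sparse: [B, C, S] features attached to the sparse points
    # returns:      [B, C, N] features transferred to the dense points
    dist = torch.cdist(xyz_dense, xyz_sparse)               # [B, N, S] pairwise distances
    dist, idx = dist.topk(k, dim=-1, largest=False)         # k nearest sparse neighbours per dense point
    weight = 1.0 / (dist + eps)
    weight = weight / weight.sum(dim=-1, keepdim=True)      # [B, N, k] normalised inverse-distance weights

    feats = feats_sparse.transpose(1, 2)                    # [B, S, C]
    B = idx.shape[0]
    batch_idx = torch.arange(B, device=idx.device).view(B, 1, 1)
    neighbours = feats[batch_idx, idx]                      # [B, N, k, C] gathered neighbour features
    out = (neighbours * weight.unsqueeze(-1)).sum(dim=2)    # [B, N, C] weighted average
    return out.transpose(1, 2)                              # [B, C, N]

# Example mirroring the fp3 step in the hunk: push 128 sparse points' features up to 256 dense points.
dense = interpolate_features(torch.randn(2, 256, 3), torch.randn(2, 128, 3), torch.randn(2, 64, 128))
print(dense.shape)   # torch.Size([2, 64, 256])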