@@ -263,3 +263,51 @@ PointRefer 包括以下核心模块:
263
263
- 生成动态卷积核(dynamic kernels);
264
264
- 最终通过卷积操作生成分割掩码;
265
265
266
+ ![ PointRefer模型结构图] ( LASO/5.png )
267
+
268
+ PointRefer 前向传播过程如下:
269
+
270
+ ``` python
271
+ class PointRefer (nn .Module ):
272
+
273
+ # 传入question文本 和 point点云数据
274
+ def forward (self , text , xyz ):
275
+
276
+ '''
277
+ text: [B, L, 768]
278
+ xyz: [B, 3, 2048]
279
+ sub_box: bounding box of the interactive subject
280
+ obj_box: bounding box of the interactive object
281
+ '''
282
+
283
+ B, C, N = xyz.size()
284
+ # 使用RoBert编码文本 ,使用PointNet++编码点云
285
+ t_feat, t_mask = self .forward_text(list (text), xyz.device) # [batch, q_len, d_model]
286
+ F_p_wise = self .point_encoder(xyz)
287
+
288
+ """
289
+ Decoding
290
+ """
291
+ p_0, p_1, p_2, p_3 = F_p_wise # 每个局部区域点坐标点,每个局部区域特征
292
+ p_3[1 ] = self .gpb(t_feat, p_3[1 ].transpose(- 2 , - 1 )).transpose(- 2 , - 1 ) # 每个区域特征充分和文本信息进行融合
293
+ up_sample = self .fp3(p_2[0 ], p_3[0 ], p_2[1 ], p_3[1 ]) # [B, emb_dim, npoint_sa2] 特征传播
294
+
295
+ up_sample = self .gpb(t_feat, up_sample.transpose(- 2 , - 1 )).transpose(- 2 , - 1 )
296
+ up_sample = self .fp2(p_1[0 ], p_2[0 ], p_1[1 ], up_sample) # [B, emb_dim, npoint_sa1]
297
+
298
+ up_sample = self .gpb(t_feat, up_sample.transpose(- 2 , - 1 )).transpose(- 2 , - 1 )
299
+ up_sample = self .fp1(p_0[0 ], p_1[0 ], torch.cat([p_0[0 ], p_0[1 ]],1 ), up_sample) # [B, emb_dim, N]
300
+
301
+ # t_feat = t_feat.sum(1)/(t_mask.float().sum(1).unsqueeze(-1))
302
+ # t_feat = t_feat.unsqueeze(1).repeat(1, self.n_groups,1)
303
+ # t_feat += self.pos1d
304
+
305
+ # print(t_feat.shape, up_sample.shape)
306
+ t_feat = self .decoder(t_feat, up_sample.transpose(- 2 , - 1 ), tgt_key_padding_mask = t_mask, query_pos = self .pos1d) # b,l,c
307
+ t_feat *= t_mask.unsqueeze(- 1 ).float()
308
+ _3daffordance = torch.einsum(' blc,bcn->bln' , t_feat, up_sample)
309
+ _3daffordance = _3daffordance.sum(1 )/ (t_mask.float().sum(1 ).unsqueeze(- 1 ))
310
+ _3daffordance = torch.sigmoid(_3daffordance)
311
+ # logits = self.cls_head(p_3[1].mean(-1))
312
+ return _3daffordance.squeeze(- 1 )
313
+ ```
0 commit comments