Skip to content

Commit ed3d45c

Browse files
committed
updates
1 parent 1eb5bb7 commit ed3d45c

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

src/3DVL/LASO.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -281,28 +281,32 @@ class PointRefer(nn.Module):
281281
'''
282282

283283
B, C, N = xyz.size()
284-
# 使用RoBert编码文本 ,使用PointNet++编码点云
284+
# 1. Encoding 过程
285+
# 1.1 Language Encoding 使用RoBert编码文本
285286
t_feat, t_mask = self.forward_text(list(text), xyz.device) # [batch, q_len, d_model]
287+
# 1.2 BackBone Encoding 使用PointNet++编码点云
286288
F_p_wise = self.point_encoder(xyz)
287289

288290
"""
289291
Decoding
290292
"""
291-
p_0, p_1, p_2, p_3 = F_p_wise # 每个局部区域点坐标点,每个局部区域特征
292-
p_3[1] = self.gpb(t_feat, p_3[1].transpose(-2, -1)).transpose(-2, -1) # 每个区域特征充分和文本信息进行融合
293-
up_sample = self.fp3(p_2[0], p_3[0], p_2[1], p_3[1]) #[B, emb_dim, npoint_sa2] 特征传播
293+
# 1.3 PointNet++ 逐级做点集抽象得到的每层的点集坐标和点集特征集合
294+
p_0, p_1, p_2, p_3 = F_p_wise
295+
296+
# 2.Backbone Decoding过程
297+
# 2.1 点集集合中每个点的特征和文本特征信息进行融合
298+
p_3[1] = self.gpb(t_feat, p_3[1].transpose(-2, -1)).transpose(-2, -1)
299+
# 2.2 PointNet++ 特征传播阶段: 上采样过程中,上一层点集中的点特征重建过程中,充分吸收了高级区域抽象特征和文本特征
300+
up_sample = self.fp3(p_2[0], p_3[0], p_2[1], p_3[1]) #[B, emb_dim, npoint_sa2]
294301

295302
up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
296303
up_sample = self.fp2(p_1[0], p_2[0], p_1[1], up_sample) #[B, emb_dim, npoint_sa1]
297304

298-
up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
305+
up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
306+
# 2.3 特征传播阶段结束: 一步步重建回原始点数量 128->256->512->1024->2048
299307
up_sample = self.fp1(p_0[0], p_1[0], torch.cat([p_0[0], p_0[1]],1), up_sample) #[B, emb_dim, N]
300-
301-
# t_feat = t_feat.sum(1)/(t_mask.float().sum(1).unsqueeze(-1))
302-
# t_feat = t_feat.unsqueeze(1).repeat(1, self.n_groups,1)
303-
# t_feat += self.pos1d
304-
305-
# print(t_feat.shape, up_sample.shape)
308+
309+
# 3. Referred Point Decoding过程
306310
t_feat = self.decoder(t_feat, up_sample.transpose(-2, -1), tgt_key_padding_mask=t_mask, query_pos=self.pos1d) # b,l,c
307311
t_feat *= t_mask.unsqueeze(-1).float()
308312
_3daffordance = torch.einsum('blc,bcn->bln', t_feat, up_sample)

0 commit comments

Comments
 (0)