Skip to content

Commit 1eb5bb7

Browse files
committed
updates
1 parent 446b994 commit 1eb5bb7

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/3DVL/LASO.md

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,51 @@ PointRefer 包括以下核心模块:
263263
- 生成动态卷积核(dynamic kernels);
264264
- 最终通过卷积操作生成分割掩码;
265265

266+
![PointRefer模型结构图](LASO/5.png)
267+
268+
PointRefer 前向传播过程如下:
269+
270+
```python
271+
class PointRefer(nn.Module):
    """Excerpt of the PointRefer model showing only its forward pass.

    The attributes used below (``forward_text``, ``point_encoder``, ``gpb``,
    ``fp1``/``fp2``/``fp3``, ``decoder`` and ``pos1d``) are defined outside
    this excerpt.
    """

    # Takes the question text and the point-cloud data.
    def forward(self, text, xyz):
        """Predict a per-point affordance score from a question and a point cloud.

        Args:
            text: Iterable of questions, one per batch item. It is converted
                with ``list()`` and encoded by ``self.forward_text``; the
                original note gives the encoded shape as [B, L, 768].
            xyz: Point-cloud tensor with exactly three dimensions; the
                original note gives its shape as [B, 3, 2048].

        Returns:
            Sigmoid-activated scores. Following the shape comments below this
            is [B, N], i.e. one value per input point (a trailing dimension
            of size 1 would be squeezed away).

        Raises:
            ValueError: If ``xyz`` is not a 3-D tensor.
        """
        # Explicit rank check; the original unpacked ``xyz.size()`` into three
        # otherwise unused names, which raised the same exception type.
        if xyz.dim() != 3:
            raise ValueError(
                f"xyz must be a 3-D tensor, got {xyz.dim()} dimension(s)"
            )

        # Encode the text with RoBert and the point cloud with PointNet++.
        t_feat, t_mask = self.forward_text(list(text), xyz.device)  # [batch, q_len, d_model]
        F_p_wise = self.point_encoder(xyz)

        # ---- Decoding ----
        # Each level holds (point coordinates, features) of its local regions.
        p_0, p_1, p_2, p_3 = F_p_wise

        # Fuse the coarsest region features with the text features. The result
        # is kept in a local so the encoder's output container is not mutated
        # (and may therefore also be a tuple).
        p_3_fused = self.gpb(t_feat, p_3[1].transpose(-2, -1)).transpose(-2, -1)
        # Feature propagation.
        up_sample = self.fp3(p_2[0], p_3[0], p_2[1], p_3_fused)  # [B, emb_dim, npoint_sa2]

        up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
        up_sample = self.fp2(p_1[0], p_2[0], p_1[1], up_sample)  # [B, emb_dim, npoint_sa1]

        up_sample = self.gpb(t_feat, up_sample.transpose(-2, -1)).transpose(-2, -1)
        up_sample = self.fp1(
            p_0[0], p_1[0], torch.cat([p_0[0], p_0[1]], 1), up_sample
        )  # [B, emb_dim, N]

        # NOTE(review): t_mask is passed as the key-padding mask here and is
        # also used below as a multiplicative weight -- confirm the decoder's
        # mask convention.
        t_feat = self.decoder(
            t_feat,
            up_sample.transpose(-2, -1),
            tgt_key_padding_mask=t_mask,
            query_pos=self.pos1d,
        )  # b, l, c

        # Zero the masked-out text positions (out of place, so the decoder
        # output itself is left untouched).
        mask = t_mask.float()
        t_feat = t_feat * mask.unsqueeze(-1)

        # Every text position acts as a kernel over the point features; the
        # responses are then averaged over the positions kept by the mask.
        _3daffordance = torch.einsum('blc,bcn->bln', t_feat, up_sample)
        _3daffordance = _3daffordance.sum(1) / mask.sum(1).unsqueeze(-1)
        _3daffordance = torch.sigmoid(_3daffordance)
        return _3daffordance.squeeze(-1)
313+
```

src/3DVL/LASO/5.png

409 KB
Loading

0 commit comments

Comments
 (0)