Skip to content

Commit 303e957

Browse files
committed
updates
1 parent 9e62197 commit 303e957

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

src/3DVL/简析PointNet++.md

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ class PointNetSetAbstraction(nn.Module):
168168
new_xyz = new_xyz.permute(0, 2, 1) # [B, C, npoint]
169169
return new_xyz, new_points # 查询点的位置(质心) , 每个查询点点局部特征。
170170
```
171-
172171
sample_and_group 这个函数的作用是从输入点云中:
173172
- 采样一些关键点
174173
- 为每个关键点构建局部邻域(局部区域)
@@ -305,4 +304,81 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
305304
group_idx[mask] = group_first[mask]
306305
return group_idx # (batch,npoint,nsample)
307306
```
307+
sample_and_group_all 函数的作用是将整个点云视为一个“大局部区域”,不进行采样,直接对所有点进行特征提取,用于 PointNet++ 中的全局特征学习。
308+
309+
```python
310+
def sample_and_group_all(xyz, points):
    """Treat the whole point cloud as one single group, without any sampling.

    Input:
        xyz: input points position data, [B, N, 3] (point coordinates)
        points: input points data, [B, N, D] (extra per-point features such
            as normals or colours), or None when there are no extra features
    Return:
        new_xyz: sampled points position data, [B, 1, 3] (all zeros)
        new_points: sampled points data, [B, 1, N, 3+D]
            ([B, 1, N, 3] when points is None)
    """
    B, N, C = xyz.shape

    # All-zero placeholder "centroid". It carries no geometric meaning; it
    # only exists so the return signature has the same form as the sampled
    # grouping path. It is allocated directly on xyz's device instead of
    # being created on the CPU and then copied over.
    new_xyz = torch.zeros(B, 1, C, device=xyz.device)

    # Expose the original cloud as one big "local region" of N points.
    grouped_xyz = xyz.view(B, 1, N, C)

    if points is None:
        # No extra features: the group consists of the coordinates only.
        return new_xyz, grouped_xyz

    # Append the extra features after the coordinates along the last axis,
    # giving one group per batch element that holds every point.
    new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    return new_xyz, new_points
333+
```
308334

335+
### 单尺度分组分类模型
336+
337+
PointNet++ 的单尺度分组(SSG)架构通过多层 Set Abstraction 提取点云的层次化特征,并最终输出分类结果。
338+
339+
> Single-Scale Grouping (SSG)
340+
341+
代码实现如下:
342+
343+
```python
344+
# pointnet2_cls_ssg.py
345+
class get_model(nn.Module):
    """PointNet++ single-scale grouping (SSG) classification network.

    Three stacked set-abstraction (SA) layers extract features level by
    level; the last one groups all remaining points. The resulting
    1024-element feature per sample is classified by a three-layer fully
    connected head.

    Args:
        num_class: number of output classes.
        normal_channel: whether the input carries normals. When true the
            input has 6 channels (x, y, z, nx, ny, nz); otherwise only the
            3 coordinate channels (x, y, z).
    """

    def __init__(self, num_class, normal_channel=True):
        super().__init__()
        self.normal_channel = normal_channel

        # Channel count fed to the first SA layer: 6 with normals, 3 without.
        first_in_channel = 3
        if normal_channel:
            first_in_channel = 6

        # Hierarchical feature learning: three SA levels. The later layers
        # take the previous layer's feature width plus 3.
        # NOTE(review): the "+ 3" presumably accounts for xyz coordinates
        # concatenated inside PointNetSetAbstraction - confirm there.
        self.sa1 = PointNetSetAbstraction(
            npoint=512,
            radius=0.2,
            nsample=32,
            in_channel=first_in_channel,
            mlp=[64, 64, 128],
            group_all=False,
        )
        self.sa2 = PointNetSetAbstraction(
            npoint=128,
            radius=0.4,
            nsample=64,
            in_channel=128 + 3,
            mlp=[128, 128, 256],
            group_all=False,
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            in_channel=256 + 3,
            mlp=[256, 512, 1024],
            group_all=True,
        )

        # Classification head: 1024 -> 512 -> 256 -> num_class.
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):
        """Classify a batch of point clouds.

        Args:
            xyz: 3-D input tensor with channels on dim 1; the first three
                channels are used as coordinates and, when normal_channel
                is set, the remaining channels are passed on as features.
                # assumes layout [B, C, N] - TODO confirm against caller

        Returns:
            Tuple of (log-probabilities of shape [B, num_class], the output
            features of the last SA layer).
        """
        batch_size, _, _ = xyz.shape

        # Split the extra channels off from the coordinates when present.
        normals = None
        if self.normal_channel:
            normals = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(xyz, normals)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)

        # Flatten the global feature to one 1024-vector per sample.
        global_feature = l3_points.view(batch_size, 1024)

        hidden = self.fc1(global_feature)
        hidden = self.drop1(F.relu(self.bn1(hidden)))
        hidden = self.fc2(hidden)
        hidden = self.drop2(F.relu(self.bn2(hidden)))

        logits = self.fc3(hidden)
        log_probs = F.log_softmax(logits, -1)

        return log_probs, l3_points
381+
```
382+
383+
384+

0 commit comments

Comments
 (0)