Skip to content

Commit 5fa6f61

Browse files
committed
updates
1 parent 303e957 commit 5fa6f61

File tree

3 files changed

+64
-55
lines changed

3 files changed

+64
-55
lines changed

src/3DVL/简析PointNet++.md

Lines changed: 64 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -113,61 +113,6 @@ FPS是一种在点云、图像处理或其他数据集中用于抽样的算法
113113

114114
#### 代码实现
115115

116-
PointNetSetAbstraction(点集抽象层) 是 PointNet++ 中的核心模块 , 它的作用是负责从输入的点云数据中采样关键点,构建它们的局部邻域区域,并通过一个小型 PointNet 提取这些区域的高维特征,从而实现点云的分层特征学习。
117-
118-
```python
119-
class PointNetSetAbstraction(nn.Module):
120-
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
121-
super(PointNetSetAbstraction, self).__init__()
122-
self.npoint = npoint # 采样的关键点数量
123-
self.radius = radius # 构建局部邻域的半径
124-
self.nsample = nsample # 每个邻域内采样的关键点数量
125-
self.mlp_convs = nn.ModuleList()
126-
self.mlp_bns = nn.ModuleList()
127-
last_channel = in_channel # 输入点的特征维度
128-
for out_channel in mlp:
129-
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
130-
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
131-
last_channel = out_channel
132-
self.group_all = group_all
133-
134-
def forward(self, xyz, points):
135-
"""
136-
Input:
137-
xyz: input points position data, [B, C, N]
138-
points: input points data, [B, D, N]
139-
Return:
140-
new_xyz: sampled points position data, [B, C, S]
141-
new_points_concat: sample points feature data, [B, D', S]
142-
"""
143-
xyz = xyz.permute(0, 2, 1) # [B, N, C]
144-
if points is not None:
145-
points = points.permute(0, 2, 1)
146-
147-
# 如果 group_all=True,则对整个点云做全局特征提取。
148-
if self.group_all:
149-
new_xyz, new_points = sample_and_group_all(xyz, points)
150-
else:
151-
# 否则使用 FPS(最远点采样)选关键点,再用 Ball Query 找出每个点的局部邻近点。
152-
# 参数: 质点数量,采样半径,采样点数量,点坐标,点额外特征
153-
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
154-
# 局部特征编码(Mini-PointNet)
155-
# new_xyz: sampled points position data, [B, npoint, C]
156-
# new_points: sampled points data, [B, npoint, nsample, C+D]
157-
# 把邻域点的数据整理成适合卷积的格式 [B, C+D, nsample, npoint]
158-
new_points = new_points.permute(0, 3, 2, 1)
159-
# 使用多个 Conv2d + BatchNorm + ReLU 层提取特征
160-
for i, conv in enumerate(self.mlp_convs):
161-
bn = self.mlp_bns[i]
162-
new_points = F.relu(bn(conv(new_points))) # [B, out_channel , nsample, npoint]
163-
164-
# 对每个局部区域内所有点的最大响应值进行池化,得到该区域的固定长度特征表示。
165-
# 在 new_points 的第 2 个维度(即每个局部邻域内的点数量维度)上做最大池化(max pooling)。
166-
# 输出形状为 [B, out_channel, nsample],即每个查询点有一个特征向量。
167-
new_points = torch.max(new_points, 2)[0]
168-
new_xyz = new_xyz.permute(0, 2, 1) # [B, C, npoint]
169-
return new_xyz, new_points # 查询点的位置(质心) , 每个查询点点局部特征。
170-
```
171116
sample_and_group 这个函数的作用是从输入点云中:
172117
- 采样一些关键点
173118
- 为每个关键点构建局部邻域(局部区域)
@@ -304,6 +249,9 @@ def query_ball_point(radius, nsample, xyz, new_xyz):
304249
group_idx[mask] = group_first[mask]
305250
return group_idx # (batch,npoint,nsample)
306251
```
252+
253+
![sample_and_group流程图](简析PointNet++/2.png)
254+
307255
sample_and_group_all 函数的作用是将整个点云视为一个“大局部区域”,不进行采样,直接对所有点进行特征提取,用于 PointNet++ 中的全局特征学习。
308256

309257
```python
@@ -331,6 +279,65 @@ def sample_and_group_all(xyz, points):
331279
new_points = grouped_xyz
332280
return new_xyz, new_points # 全局质心点(0 位置), 所有点组成的局部区域
333281
```
282+
![sample_and_group_all流程图](简析PointNet++/3.png)
283+
284+
PointNetSetAbstraction(点集抽象层) 是 PointNet++ 中的核心模块 , 它的作用是负责从输入的点云数据中采样关键点,构建它们的局部邻域区域,并通过一个小型 PointNet 提取这些区域的高维特征,从而实现点云的分层特征学习。
285+
286+
```python
287+
class PointNetSetAbstraction(nn.Module):
288+
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
289+
super(PointNetSetAbstraction, self).__init__()
290+
self.npoint = npoint # 采样的关键点数量
291+
self.radius = radius # 构建局部邻域的半径
292+
self.nsample = nsample # 每个邻域内采样的关键点数量
293+
self.mlp_convs = nn.ModuleList()
294+
self.mlp_bns = nn.ModuleList()
295+
last_channel = in_channel # 输入点的特征维度
296+
for out_channel in mlp:
297+
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
298+
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
299+
last_channel = out_channel
300+
self.group_all = group_all
301+
302+
def forward(self, xyz, points):
303+
"""
304+
Input:
305+
xyz: input points position data, [B, C, N]
306+
points: input points data, [B, D, N]
307+
Return:
308+
new_xyz: sampled points position data, [B, C, S]
309+
new_points_concat: sample points feature data, [B, D', S]
310+
"""
311+
xyz = xyz.permute(0, 2, 1) # [B, N, C]
312+
if points is not None:
313+
points = points.permute(0, 2, 1)
314+
315+
# 如果 group_all=True,则对整个点云做全局特征提取。
316+
if self.group_all:
317+
new_xyz, new_points = sample_and_group_all(xyz, points)
318+
else:
319+
# 否则使用 FPS(最远点采样)选关键点,再用 Ball Query 找出每个点的局部邻近点。
320+
# 参数: 质点数量,采样半径,采样点数量,点坐标,点额外特征
321+
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
322+
# 局部特征编码(Mini-PointNet)
323+
# new_xyz: sampled points position data, [B, npoint, C]
324+
# new_points: sampled points data, [B, npoint, nsample, C+D]
325+
# 把邻域点的数据整理成适合卷积的格式 [B, C+D, nsample, npoint]
326+
new_points = new_points.permute(0, 3, 2, 1)
327+
# 使用多个 Conv2d + BatchNorm + ReLU 层提取特征
328+
for i, conv in enumerate(self.mlp_convs):
329+
bn = self.mlp_bns[i]
330+
new_points = F.relu(bn(conv(new_points))) # [B, out_channel , nsample, npoint]
331+
332+
# 对每个局部区域内所有点的最大响应值进行池化,得到该区域的固定长度特征表示。
333+
# 在 new_points 的第 2 个维度(即每个局部邻域内的点数量维度)上做最大池化(max pooling)。
334+
# 输出形状为 [B, out_channel, npoint],即每个查询点有一个特征向量。
335+
new_points = torch.max(new_points, 2)[0] # [B, out_channel, npoint]
336+
new_xyz = new_xyz.permute(0, 2, 1) # [B, C, npoint]
337+
return new_xyz, new_points # 查询点的位置(质心) , 每个查询点点局部特征。
338+
```
339+
最终每个采样得到的关键点所在的局部领域,都会被压缩为一个固定长度的特征向量。这个特征向量代表了这个局部区域的高维特征,它包含了这个区域内所有点的信息。
340+
334341

335342
### 单尺度分组分类模型
336343

@@ -379,6 +386,8 @@ class get_model(nn.Module):
379386

380387
return x, l3_points
381388
```
389+
完整的单尺度分组分类流程如下图所示:
390+
382391

383392

384393

src/3DVL/简析PointNet++/2.png

1.19 MB
Loading

src/3DVL/简析PointNet++/3.png

411 KB
Loading

0 commit comments

Comments
 (0)