Skip to content

Commit a53c78a

Browse files
committed
updates
1 parent 9cd1181 commit a53c78a

File tree

1 file changed

+189
-18
lines changed

1 file changed

+189
-18
lines changed

src/3DVL/LASO.md

Lines changed: 189 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -319,32 +319,109 @@ class PointRefer(nn.Module):
319319
320320
### AFM 自适应融合模块
321321

322-
```python
323-
class GPBlock(nn.Module):
324-
# q: 文本特征 (b, l, c) , x: 点集特征集合 (b, n, c)
325-
def forward(self, q, x, q_mask=None):
326-
# 1. 注意力: 文本特征作为query,从点集特征集合中提取相关区域特征
327-
gt = self.group_layer(query=q, key=x, value=x)
328-
if q_mask is not None:
329-
gt *= q_mask.unsqueeze(-1)
330-
# 2. MLP: 做token维度和channel维度的信息融合
331-
gt = self.mixer(gt) + self.drop(gt)
332-
ungroup_tokens = self.un_group_layer(query=x, key=gt, value=gt, key_padding_mask=q_mask)
333-
return ungroup_tokens
322+
在 LASO 任务中,模型需要根据自然语言问题(如 “Where to grasp?”)识别点云中的功能区域。由于目标功能区域的尺度、形状多样,传统方法难以适应不同情况。为此,作者设计了 AFM 模块,以增强 PointNet++ 解码过程中点特征的语言引导能力。
323+
324+
AFM 的目标是:在不同解码阶段注入语言线索(text clues),将文本语义信息与点云特征进行跨模态融合,逐步以自上而下的方式细化点特征图,从而提升模型对多尺度、多形状的功能区域的感知能力。
334325

326+
AFM 遵循一个 **瓶颈式架构(bottleneck architecture)**,包含三个关键步骤:
327+
328+
1. **Grouping(分组)**
329+
2. **Mixing(混合)**
330+
3. **Ungrouping(解组)**
331+
332+
这三个步骤构成了一个完整的跨模态融合流程。
333+
334+
#### 1️⃣ Grouping:文本引导的点特征分组
335+
336+
输入:
337+
- `X ∈ R^{L×d}`:问题编码后的文本特征(由 RoBERTa 编码得到)
338+
- `P ∈ R^{T×d}`:某一层解码器输出的点特征,其中 T 表示该层点数
339+
340+
处理过程:
341+
- 使用一个轻量级的交叉注意力模块,将文本特征作为查询(query),点特征作为键(key)和值(value),输出分组标记 G:
342+
343+
$$
344+
G = \text{Attention}(X, W_1P, P) + X
345+
$$
346+
$$
347+
\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
348+
$$
349+
350+
其中:
351+
- $W_1$ 是一个线性变换;
352+
- 注意力机制使得每个文本 token 对应一组相关的点特征;
353+
- 分组操作实现了“语言引导的点特征筛选”。
354+
355+
> 重点是如何理解这里的分组: 每个文本Token询问所有点Key后,知道了哪些点跟自身的相关度更大,因此加权融合的时候,侧重于给这些点的特征分配更大的融合权重。
356+
357+
这部分代码实现如下:
358+
359+
```python
335360
# group_layer 的实现
336361
class LightGroupAttnBlock(nn.Module):
337-
362+
363+
# query 是RoBerta编码后的文本特征 , (b,l,c)
364+
# key和value都是点云特征 , (b,n,c)
338365
def forward(self, query, key, value, q_mask=None):
339366
def _inner_forward(query, key, value):
340367
q = self.norm_query(query)
341368
k = q if self.key_is_query else self.norm_key(key)
342369
v = k if self.value_is_key else self.norm_value(value)
343-
x = self.attn(q, k, v, q_mask) + self.drop(q) # 残差连接
370+
# 让每个语言 token 去关注点云中最相关的区域,并在此基础上强化自身的语义表达。
371+
# 加上原始 X 是一种残差连接(Residual Connection),可以确保语言语义不会丢失。
372+
x = self.attn(q, k, v, q_mask) + self.drop(q)
344373
return x
345-
374+
346375
return _inner_forward(query, key, value)
376+
```
377+
378+
---
379+
380+
#### 2️⃣ Mixing:MLP-Mixer 进行组内和通道间的信息混合
381+
382+
MLP-Mixer 是一种 基于 MLP(多层感知机)的视觉模型架构 ,由 Google Research 在 2021 年提出。它不使用任何注意力机制,而是通过 空间混合(mixing)和通道混合(mixing)操作 来实现全局信息建模。
383+
384+
> [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
385+
386+
MLP-Mixer 的核心思想是:用 MLP 替代 Transformer 中的自注意力机制 ,从而减少计算复杂度并保持性能。
387+
388+
1. Token-mixing MLP
347389

390+
- 对所有点/patch 的相同通道进行混合;
391+
- 相当于跨空间位置的信息交换;
392+
- 类似于 CNN 中的空间卷积;
393+
394+
2. Channel-mixing MLP
395+
396+
- 对每个 token 的所有通道进行处理;
397+
- 提取更高级的特征表示;
398+
- 类似于传统的全连接层或 1x1 卷积;
399+
400+
这两个操作交替进行,形成一个类似于 Transformer 的堆叠结构,但完全不使用注意力机制。
401+
402+
输入:
403+
- `G ∈ R^{L×d}`:分组后的文本引导特征
404+
405+
处理过程:
406+
407+
- 使用 MLP-Mixer 来更新分组特征,生成融合特征 F:
408+
409+
$$
410+
G' = G + \text{MLP}_1(G^T)^T
411+
$$
412+
$$
413+
F = G' + \text{MLP}_2(G')
414+
$$
415+
416+
其中:
417+
- `MLP₁` 负责组内信息混合(token 内部);
418+
- `MLP₂` 负责通道间信息混合(feature channel);
419+
- 两个 MLP 交替作用,实现跨模态信息的充分交互;
420+
- 最终输出融合特征 `F`
421+
422+
这部分代码实现如下:
423+
424+
```python
348425
# mixer 的实现
349426
class MLPMixerLayer(nn.Module):
350427
def __init__(self,
@@ -378,13 +455,107 @@ class MLPMixerLayer(nn.Module):
378455

379456
self.norm1 = nn.LayerNorm(embed_dims)
380457
self.norm2 = nn.LayerNorm(embed_dims)
381-
458+
459+
# x 分组后的文本引导特征 : (b,l,c)
382460
def forward(self, x):
383-
# x_mask = (x.sum(-1)!=0).to(x.dtype)
461+
# x 转置后: (b,c,l) , patch_mixer 负责组内信息混合(token 内部)
384462
x = x + self.patch_mixer(self.norm1(x).transpose(1,2)).transpose(1,2)
463+
# channel_mixer 负责通道间信息混合(feature channel)
385464
x = x + self.channel_mixer(self.norm2(x))
386-
# x *= x_mask
387465
return x
388466
```
467+
---
468+
469+
#### 3️⃣ Ungrouping:将融合特征映射回点空间
470+
471+
输入:
472+
473+
- 原始点特征 `P`
474+
- 融合后的文本特征 `F`
475+
476+
处理过程:
477+
478+
- 使用另一个注意力模块,将融合特征重新分配给每个点:
479+
480+
$$
481+
P_m = \text{Attention}(P, W_2F, F) + P
482+
$$
483+
484+
其中:
485+
- `W₂` 是线性变换;
486+
- 注意力机制让每个点从融合特征中提取相关信息;
487+
- 输出 `P_m` 是语言增强后的点特征;
488+
- 最后加上残差连接形成最终输出 `P_o`
489+
490+
$$
491+
P_o = P_m + \text{residual}
492+
$$
493+
494+
这个 `P_o` 就是经过 AFM 增强的点特征图,用于后续分割掩码预测。
495+
496+
```python
497+
class FullAttnCatBlock(nn.Module):
498+
# query 为点云: (b,n,c) , key和value为融合后的文本特征: (b,l,c)
499+
def forward(self, query, key, value, key_padding_mask=None):
500+
def _inner_forward(query, key, value, key_padding_mask):
501+
q = self.norm_query(query)
502+
k = q if self.key_is_query else self.norm_key(key)
503+
v = k if self.value_is_key else self.norm_value(value)
504+
505+
# 使用另一个注意力模块,将融合特征重新分配给每个点
506+
x = self.attn(q, k, v, key_padding_mask) + self.drop(query)
507+
# MLP映射 + Residual Connection
508+
x = self.ffn(self.norm2(x)) + x
509+
return x
510+
511+
return _inner_forward(query, key, value, key_padding_mask)
512+
```
513+
---
514+
#### 4️⃣ AFM 自适应融合模块
515+
516+
有了以上 Grouping - Mixing - Ungrouping 三个关键步骤的实现,下面只需要把以上的三个步骤按流程组织起来即可得到AFM模块的完整实现了:
517+
518+
```python
519+
class GPBlock(nn.Module):
520+
# q: 文本特征 (b, l, c) , x: 点集特征集合 (b, n, c)
521+
def forward(self, q, x, q_mask=None):
522+
# Grouping阶段
523+
gt = self.group_layer(query=q, key=x, value=x)
524+
if q_mask is not None:
525+
gt *= q_mask.unsqueeze(-1)
526+
# Mixing阶段
527+
gt = self.mixer(gt) + self.drop(gt)
528+
# Ungrouping阶段
529+
ungroup_tokens = self.un_group_layer(query=x, key=gt, value=gt, key_padding_mask=q_mask)
530+
return ungroup_tokens
531+
```
532+
533+
AFM 的网络结构可视化理解
534+
535+
```
536+
文本特征 X ──┐
537+
538+
Grouping (Cross-Attention)
539+
540+
Mixing (MLP-Mixer)
541+
542+
Ungrouping (Attention)
543+
544+
输出增强后的点特征 P_o
545+
```
546+
547+
- **Grouping**:用语言引导点特征分组;
548+
- **Mixing**:在分组内进行信息交换;
549+
- **Ungrouping**:再将融合信息返回点空间;
550+
551+
这种设计使得语言信息能有效地指导点特征的学习过程,论文中也进行了大量消融实验来验证 AFM 的有效性:
552+
553+
| 模型变体 | mIoU | AUC | SIM | MAE |
554+
|----------|-------|-----|-----|-----|
555+
| 基线(不加 AFM) | 17.7 | 82.1 | 0.558 | 0.110 |
556+
| 加入 AFM 后 | **20.8** | **87.3** | **0.629** | **0.093** |
557+
558+
结果表明:加入 AFM 显著提升了所有指标,说明其确实有效增强了语言-视觉的跨模态交互能力。
559+
389560

390561

0 commit comments

Comments
 (0)