@@ -319,32 +319,109 @@ class PointRefer(nn.Module):
319
319
320
320
### AFM 自适应融合模块
321
321
322
- ``` python
323
- class GPBlock (nn .Module ):
324
- # q: 文本特征 (b, l, c) , x: 点集特征集合 (b, n, c)
325
- def forward (self , q , x , q_mask = None ):
326
- # 1. 注意力: 文本特征作为query,从点集特征集合中提取相关区域特征
327
- gt = self .group_layer(query = q, key = x, value = x)
328
- if q_mask is not None :
329
- gt *= q_mask.unsqueeze(- 1 )
330
- # 2. MLP: 做token维度和channel维度的信息融合
331
- gt = self .mixer(gt) + self .drop(gt)
332
- ungroup_tokens = self .un_group_layer(query = x, key = gt, value = gt, key_padding_mask = q_mask)
333
- return ungroup_tokens
322
+ 在 LASO 任务中,模型需要根据自然语言问题(如 “Where to grasp?”)识别点云中的功能区域。由于目标功能区域的尺度、形状多样,传统方法难以适应不同情况。为此,作者设计了 AFM 模块,以增强 PointNet++ 解码过程中点特征的语言引导能力。
323
+
324
+ AFM 的目标是:在不同解码阶段注入语言线索(text clues),将文本语义信息与点云特征进行跨模态融合,逐步以自上而下的方式细化点特征图,从而提升模型对多尺度、多形状的功能区域的感知能力。
334
325
326
+ AFM 遵循一个 ** 瓶颈式架构(bottleneck architecture)** ,包含三个关键步骤:
327
+
328
+ 1 . ** Grouping(分组)**
329
+ 2 . ** Mixing(混合)**
330
+ 3 . ** Ungrouping(解组)**
331
+
332
+ 这三个步骤构成了一个完整的跨模态融合流程。
333
+
334
+ #### 1️⃣ Grouping:文本引导的点特征分组
335
+
336
+ 输入:
337
+ - ` X ∈ R^{L×d} ` :问题编码后的文本特征(由 RoBERTa 编码得到)
338
+ - ` P ∈ R^{T×d} ` :某一层解码器输出的点特征,其中 T 表示该层点数
339
+
340
+ 处理过程:
341
+ - 使用一个轻量级的交叉注意力模块,将文本特征作为查询(query),点特征作为键(key)和值(value),输出分组标记 G:
342
+
343
+ $$
344
+ G = \text{Attention}(X, W_1P, P) + X
345
+ $$
346
+ $$
347
+ \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
348
+ $$
349
+
350
+ 其中:
351
+ - $W_1$ 是一个线性变换;
352
+ - 注意力机制使得每个文本 token 对应一组相关的点特征;
353
+ - 分组操作实现了“语言引导的点特征筛选”。
354
+
355
+ > 重点是如何理解这里的分组: 每个文本Token询问所有点Key后,知道了哪些点跟自身的相关度更大,因此加权融合的时候,侧重于给这些点的特征分配更大的融合权重。
356
+
357
+ 这部分代码实现如下:
358
+
359
+ ``` python
335
360
# group_layer 的实现
336
361
class LightGroupAttnBlock (nn .Module ):
337
-
362
+
363
+ # query 是RoBerta编码后的文本特征 , (b,l,c)
364
+ # key和value都是点云特征 , (b,n,c)
338
365
def forward (self , query , key , value , q_mask = None ):
339
366
def _inner_forward (query , key , value ):
340
367
q = self .norm_query(query)
341
368
k = q if self .key_is_query else self .norm_key(key)
342
369
v = k if self .value_is_key else self .norm_value(value)
343
- x = self .attn(q, k, v, q_mask) + self .drop(q) # 残差连接
370
+ # 让每个语言 token 去关注点云中最相关的区域,并在此基础上强化自身的语义表达。
371
+ # 加上原始 X 是一种残差连接(Residual Connection),可以确保语言语义不会丢失。
372
+ x = self .attn(q, k, v, q_mask) + self .drop(q)
344
373
return x
345
-
374
+
346
375
return _inner_forward(query, key, value)
376
+ ```
377
+
378
+ ---
379
+
380
+ #### 2️⃣ Mixing:MLP-Mixer 进行组内和通道间的信息混合
381
+
382
+ MLP-Mixer 是一种 基于 MLP(多层感知机)的视觉模型架构 ,由 Google Research 在 2021 年提出。它不使用任何注意力机制,而是通过 空间混合(mixing)和通道混合(mixing)操作 来实现全局信息建模。
383
+
384
+ > [ MLP-Mixer: An all-MLP Architecture for Vision] ( https://arxiv.org/abs/2105.01601 )
385
+
386
+ MLP-Mixer 的核心思想是:用 MLP 替代 Transformer 中的自注意力机制 ,从而减少计算复杂度并保持性能。
387
+
388
+ 1 . Token-mixing MLP
347
389
390
+ - 对所有点/patch 的相同通道进行混合;
391
+ - 相当于跨空间位置的信息交换;
392
+ - 类似于 CNN 中的空间卷积;
393
+
394
+ 2 . Channel-mixing MLP
395
+
396
+ - 对每个 token 的所有通道进行处理;
397
+ - 提取更高级的特征表示;
398
+ - 类似于传统的全连接层或 1x1 卷积;
399
+
400
+ 这两个操作交替进行,形成一个类似于 Transformer 的堆叠结构,但完全不使用注意力机制。
401
+
402
+ 输入:
403
+ - ` G ∈ R^{L×d} ` :分组后的文本引导特征
404
+
405
+ 处理过程:
406
+
407
+ - 使用 MLP-Mixer 来更新分组特征,生成融合特征 F:
408
+
409
+ $$
410
+ G' = G + \text{MLP}_1(G^T)^T
411
+ $$
412
+ $$
413
+ F = G' + \text{MLP}_2(G')
414
+ $$
415
+
416
+ 其中:
417
+ - ` MLP₁ ` 负责组内信息混合(token 内部);
418
+ - ` MLP₂ ` 负责通道间信息混合(feature channel);
419
+ - 两个 MLP 交替作用,实现跨模态信息的充分交互;
420
+ - 最终输出融合特征 ` F ` ;
421
+
422
+ 这部分代码实现如下:
423
+
424
+ ``` python
348
425
# mixer 的实现
349
426
class MLPMixerLayer (nn .Module ):
350
427
def __init__ (self ,
@@ -378,13 +455,107 @@ class MLPMixerLayer(nn.Module):
378
455
379
456
self .norm1 = nn.LayerNorm(embed_dims)
380
457
self .norm2 = nn.LayerNorm(embed_dims)
381
-
458
+
459
+ # x 分组后的文本引导特征 : (b,l,c)
382
460
def forward (self , x ):
383
- # x_mask = (x.sum(-1)!=0).to(x.dtype)
461
+ # x 转置后: (b,c,l) , patch_mixer 负责组内信息混合(token 内部)
384
462
x = x + self .patch_mixer(self .norm1(x).transpose(1 ,2 )).transpose(1 ,2 )
463
+ # channel_mixer 负责通道间信息混合(feature channel)
385
464
x = x + self .channel_mixer(self .norm2(x))
386
- # x *= x_mask
387
465
return x
388
466
```
467
+ ---
468
+
469
+ #### 3️⃣ Ungrouping:将融合特征映射回点空间
470
+
471
+ 输入:
472
+
473
+ - 原始点特征 ` P ` ;
474
+ - 融合后的文本特征 ` F ` ;
475
+
476
+ 处理过程:
477
+
478
+ - 使用另一个注意力模块,将融合特征重新分配给每个点:
479
+
480
+ $$
481
+ P_m = \text{Attention}(P, W_2F, F) + P
482
+ $$
483
+
484
+ 其中:
485
+ - ` W₂ ` 是线性变换;
486
+ - 注意力机制让每个点从融合特征中提取相关信息;
487
+ - 输出 ` P_m ` 是语言增强后的点特征;
488
+ - 最后加上残差连接形成最终输出 ` P_o ` :
489
+
490
+ $$
491
+ P_o = P_m + \text{residual}
492
+ $$
493
+
494
+ 这个 ` P_o ` 就是经过 AFM 增强的点特征图,用于后续分割掩码预测。
495
+
496
+ ``` python
497
+ class FullAttnCatBlock (nn .Module ):
498
+ # query 为点云: (b,n,c) , key和value为融合后的文本特征: (b,l,c)
499
+ def forward (self , query , key , value , key_padding_mask = None ):
500
+ def _inner_forward (query , key , value , key_padding_mask ):
501
+ q = self .norm_query(query)
502
+ k = q if self .key_is_query else self .norm_key(key)
503
+ v = k if self .value_is_key else self .norm_value(value)
504
+
505
+ # 使用另一个注意力模块,将融合特征重新分配给每个点
506
+ x = self .attn(q, k, v, key_padding_mask) + self .drop(query)
507
+ # MLP映射 + Residual Connection
508
+ x = self .ffn(self .norm2(x)) + x
509
+ return x
510
+
511
+ return _inner_forward(query, key, value, key_padding_mask)
512
+ ```
513
+ ---
514
+ #### 4️⃣ AFM 自适应融合模块
515
+
516
+ 有了以上 Grouping - Mixing - Ungrouping 三个关键步骤的实现,下面只需要把以上的三个步骤按流程组织起来即可得到AFM模块的完整实现了:
517
+
518
+ ``` python
519
+ class GPBlock (nn .Module ):
520
+ # q: 文本特征 (b, l, c) , x: 点集特征集合 (b, n, c)
521
+ def forward (self , q , x , q_mask = None ):
522
+ # Grouping阶段
523
+ gt = self .group_layer(query = q, key = x, value = x)
524
+ if q_mask is not None :
525
+ gt *= q_mask.unsqueeze(- 1 )
526
+ # Mixing阶段
527
+ gt = self .mixer(gt) + self .drop(gt)
528
+ # Ungrouping阶段
529
+ ungroup_tokens = self .un_group_layer(query = x, key = gt, value = gt, key_padding_mask = q_mask)
530
+ return ungroup_tokens
531
+ ```
532
+
533
+ AFM 的网络结构可视化理解
534
+
535
+ ```
536
+ 文本特征 X ──┐
537
+ ↓
538
+ Grouping (Cross-Attention)
539
+ ↓
540
+ Mixing (MLP-Mixer)
541
+ ↓
542
+ Ungrouping (Attention)
543
+ ↓
544
+ 输出增强后的点特征 P_o
545
+ ```
546
+
547
+ - ** Grouping** :用语言引导点特征分组;
548
+ - ** Mixing** :在分组内进行信息交换;
549
+ - ** Ungrouping** :再将融合信息返回点空间;
550
+
551
+ 这种设计使得语言信息能有效地指导点特征的学习过程,论文中也进行了大量消融实验来验证 AFM 的有效性:
552
+
553
+ | 模型变体 | mIoU | AUC | SIM | MAE |
554
+ | ----------| -------| -----| -----| -----|
555
+ | 基线(不加 AFM) | 17.7 | 82.1 | 0.558 | 0.110 |
556
+ | 加入 AFM 后 | ** 20.8** | ** 87.3** | ** 0.629** | ** 0.093** |
557
+
558
+ 结果表明:加入 AFM 显著提升了所有指标,说明其确实有效增强了语言-视觉的跨模态交互能力。
559
+
389
560
390
561
0 commit comments