@@ -29,6 +29,66 @@ author:
29
29
30
30
## 模型结构
31
31
32
+ ### Fusion 多模态特征融合模块
33
+
34
+ ``python
35
+ class Fusion(nn.Module):
36
+ def __ init__ (self, emb_dim = 512, num_heads = 4):
37
+ super().__ init__ ()
38
+ self.emb_dim = emb_dim
39
+ # 对点积结果进行缩放,防止 softmax 梯度消失或爆炸。
40
+ self.div_scale = self.emb_dim ** (-0.5)
41
+ self.num_heads = num_heads
42
+
43
+ # 对图像和点云特征进行 非线性增强和空间对齐 ,使得它们能够在统一的语义空间中进行有效的跨模态交互。
44
+ self.mlp = nn.Sequential(
45
+ nn.Conv1d(self.emb_dim, 2* self.emb_dim, 1, 1),
46
+ nn.BatchNorm1d(2* self.emb_dim),
47
+ nn.ReLU(),
48
+ nn.Conv1d(2* self.emb_dim, self.emb_dim, 1, 1),
49
+ nn.BatchNorm1d(self.emb_dim),
50
+ nn.ReLU()
51
+ )
52
+
53
+ self.img_attention = Self_Attention(self.emb_dim, self.num_heads)
54
+ self.point_attention = Self_Attention(self.emb_dim, self.num_heads)
55
+ self.joint_attention = Self_Attention(self.emb_dim, self.num_heads)
56
+
57
+ def forward(self, img_feature, point_feature):
58
+ '''
59
+ i_feature: [B, C, H, W]
60
+ p_feature: [B, C, N_p]
61
+ HW = N_i
62
+ '''
63
+ B, C, H, W = img_feature.size()
64
+ img_feature = img_feature.view(B, self.emb_dim, -1) #[B, C, N_i]
65
+ point_feature = point_feature[-1][1]
66
+
67
+ # 对图像和点云特征进行 非线性增强和空间对齐 ,使得它们能够在统一的语义空间中进行有效的跨模态交互。
68
+ p_feature = self.mlp(point_feature)
69
+ i_feature = self.mlp(img_feature)
70
+
71
+ # 跨模态注意力矩阵: 每个点云点与图像中每个位置之间的相似度得分
72
+ phi = torch.bmm(p_feature.permute(0, 2, 1), i_feature)*self.div_scale #[B, N_p, N_i]
73
+ # 每列是一个 softmax 分布(每个图像位置对应的所有点云点), 表示:“对于图像中的每一个位置,应该关注哪些点云点?”
74
+ phi_p = F.softmax(phi,dim=1)
75
+ # 每行是一个 softmax 分布(每个点云点对应的所有图像位置), 表示:“对于点云中的每一个点,应该关注图像中的哪些位置?”
76
+ phi_i = F.softmax(phi,dim=-1)
77
+ # I_enhance 是图像 patch 引导下提取的点云信息增强后的图像特征
78
+ # 它不是直接包含原始图像 patch 的语义
79
+ # 而是通过“点云中相关点”的方式重构图像 patch 的语义
80
+ I_enhance = torch.bmm(p_feature, phi_p) #[B, C, N_i]
81
+ # P_enhance 是每个点云点引导下提取的图像信息增强后的点云点特征
82
+ P_enhance = torch.bmm(i_feature, phi_i.permute(0,2,1)) #[B, C, N_p]
83
+ I = self.img_attention(I_enhance.mT) #[B, N_i, C]
84
+ P = self.point_attention(P_enhance.mT) #[B, N_p, C]
85
+
86
+ joint_patch = torch.cat((P, I), dim=1)
87
+ multi_feature = self.joint_attention(joint_patch) #[B, N_p+N_i, C]
88
+
89
+ return multi_feature
90
+ ```
91
+
32
92
33
93
34
94
0 commit comments