@@ -133,7 +133,7 @@ author:
133
133
134
134
### 7. 代码实现
135
135
136
- 数据集加载的核心代码实现如下:
136
+ ** 数据集初始化的核心代码实现如下: **
137
137
138
138
``` python
139
139
class AffordQ (Dataset ):
@@ -163,18 +163,22 @@ class AffordQ(Dataset):
163
163
with open (os.path.join(data_root, f ' objects_ { split} .pkl ' ), ' rb' ) as f:
164
164
self .objects = pickle.load(f)
165
165
166
- # 加载58种物体-功能组合的标注数据
166
+ # 加载58种物体-功能组合的标注数据 (数据组织形式,参考上文的 Affordance-Question数据可视化图)
167
167
self .question_df = pd.read_csv(os.path.join(data_root, ' Affordance-Question.csv' ))
168
168
169
- # sort anno by object class and affordance type
169
+ # sort anno by object class and affordance type -- 遍历标注数据列表
170
170
self .sort_anno = {}
171
171
for item in sorted (self .anno, key = lambda x : x[' class' ]):
172
+ # 获取当前样本的物体类别和物体信息值: 点云ID, 功能区域掩码, 功能类别
172
173
key = item[' class' ]
173
174
value = {' shape_id' : item[' shape_id' ], ' mask' : item[' mask' ], ' affordance' : item[' affordance' ]}
174
175
176
+ # 每种物体可以对应多种形状实例和功能类别
175
177
if key not in self .sort_anno:
178
+ # 如果当前物体类别不在排序后的字典中,直接添加
176
179
self .sort_anno[key] = [value]
177
180
else :
181
+ # 如果当前物体类别在排序后的字典中,将当前样本的物体信息值追加到对应列表中
178
182
self .sort_anno[key].append(value)
179
183
```
180
184
加载的标注数据中每个样本的组织形式如下:
@@ -185,6 +189,49 @@ class AffordQ(Dataset):
185
189
186
190
![ 标注数据组织形式] ( LASO/2.png )
187
191
192
+ ![ 点云数据组织形式] ( LASO/3.png )
193
+
194
+ ![ 每种物体可以对应多种形状实例和功能类别] ( LASO/4.png )
195
+
196
+ ** 获取样本的代码实现:**
197
+
198
+ ``` python
199
+ def __getitem__ (self , index ):
200
+ # 根据样本索引取出样本数据
201
+ data = self .anno[index]
202
+ # 获取当前样本对应的点云ID
203
+ shape_id = data[' shape_id' ]
204
+ # 获取当前样本对应的物体类别
205
+ cls = data[' class' ]
206
+ # 获取当前样本对应的功能类型
207
+ affordance = data[' affordance' ]
208
+ # 获取当前样本对应的功能区域掩码
209
+ gt_mask = data[' mask' ]
210
+ # 取出当前样本对应的点云数据 ,(2048,3)
211
+ point_set = self .objects[str (shape_id)]
212
+ # 对点云数据进行归一化处理,消除尺度差异
213
+ point_set,_,_ = pc_normalize(point_set)
214
+ # 对点云数据进行转置操作 ,(3,2048)
215
+ point_set = point_set.transpose()
216
+
217
+ #
218
+ question = self .find_rephrase(self .question_df, cls , affordance)
219
+ affordance = self .aff2idx[affordance]
220
+
221
+ return point_set, self .cls2idx[cls ], gt_mask, question, affordance
222
+
223
+ def find_rephrase (self , df , object_name , affordance ):
224
+ # 如果当前是训练模式,则从问题1~15中随机选择一个问题,否则固定返回问题0
225
+ qid = str (np.random.randint(1 , 15 )) if self .split == ' train' else ' 0'
226
+ qid = ' Question' + qid
227
+ #
228
+ result = df.loc[(df[' Object' ] == object_name) & (df[' Affordance' ] == affordance), [qid]]
229
+ if not result.empty:
230
+ # return result.index[0], result.iloc[0]['Rephrase']
231
+ return result.iloc[0 ][qid]
232
+ else :
233
+ raise NotImplementedError
234
+ ```
188
235
189
236
### 8. 总结
190
237
0 commit comments