Skip to content

Commit 991e3e5

Browse files
committed
updates
1 parent c3a88c8 commit 991e3e5

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

src/3DVL/LASO.md

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,3 +674,145 @@ _3daffordance = torch.sigmoid(_3daffordance)
674674
- 在 token 维度求和(或平均池化),融合多个 token 的关注信息;
675675
- 使用 sigmoid 得到最终的掩码,形状为 `(B, N)`;
676676
- 每个点的值 ∈ [0, 1],表示其属于目标功能区域的概率;
677+
678+
## 训练
679+
680+
训练部分的核心代码实现如下:
681+
682+
```python
683+
def main(opt, dict):
684+
# 1. 加载训练集,验证集,测试集
685+
train_dataset = AffordQ('train')
686+
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8 ,shuffle=True, drop_last=True)
687+
val_dataset = AffordQ('val')
688+
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
689+
test_dataset = AffordQ('test')
690+
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
691+
692+
# 2. 初始化模型
693+
model = get_PointRefer(emb_dim=dict['emb_dim'],
694+
proj_dim=dict['proj_dim'], num_heads=dict['num_heads'], N_raw=dict['N_raw'],
695+
num_affordance = dict['num_affordance'], n_groups=opt.n_groups)
696+
697+
# 3. 初始化损失函数,优化器,学习率调度器
698+
criterion_hm = HM_Loss()
699+
criterion_ce = nn.CrossEntropyLoss()
700+
param_dicts = [
701+
{"params": [p for n, p in model.named_parameters() if "text_encoder" not in n and p.requires_grad]},
702+
{"params": [p for n, p in model.named_parameters() if "text_encoder" in n and p.requires_grad], "lr": opt.tlr}]
703+
optimizer = torch.optim.Adam(params = param_dicts, lr=dict['lr'], betas=(0.9, 0.999), eps=1e-8, weight_decay=opt.decay_rate)
704+
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=dict['Epoch'], eta_min=1e-6)
705+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
706+
707+
'''
708+
Training
709+
'''
710+
for epoch in range(start_epoch+1, dict['Epoch']):
711+
num_batches = len(train_loader)
712+
loss_sum = 0
713+
total_point = 0
714+
model = model.train()
715+
716+
for i,(point, cls, gt_mask, question, aff_label) in enumerate(train_loader):
717+
optimizer.zero_grad()
718+
# 4. 前向传播过程
719+
_3d = model(question, point)
720+
# 5. 计算损失
721+
loss_hm = criterion_hm(_3d, gt_mask)
722+
# loss_ce = criterion_ce(logits, cls)
723+
724+
temp_loss = loss_hm # + opt.loss_cls*loss_ce
725+
temp_loss.backward()
726+
optimizer.step()
727+
728+
results = torch.zeros((len(val_dataset), 2048, 1))
729+
targets = torch.zeros((len(val_dataset), 2048, 1))
730+
731+
'''
732+
Evaluation
733+
'''
734+
if((epoch+1)%1 == 0):
735+
num = 0
736+
with torch.no_grad():
737+
num_batches = len(val_loader)
738+
val_loss_sum = 0
739+
total_MAE = 0
740+
total_point = 0
741+
model = model.eval()
742+
for i,(point, _, label, question,aff_label) in enumerate(val_loader):
743+
point, label = point.float(), label.float()
744+
745+
_3d = model(question, point)
746+
mae, point_nums = evaluating(_3d, label)
747+
total_point += point_nums
748+
# val_loss_sum += val_loss.item()
749+
total_MAE += mae.item()
750+
pred_num = _3d.shape[0]
751+
# print(f'---val_loss | {val_loss.item()}')
752+
results[num : num+pred_num, :, :] = _3d.unsqueeze(-1)
753+
targets[num : num+pred_num, :, :] = label.unsqueeze(-1)
754+
num += pred_num
755+
756+
# val_mean_loss = val_loss_sum / num_batches
757+
# logger.debug(f'Epoch_{epoch} | val_loss | {val_mean_loss}')
758+
mean_mae = total_MAE / total_point
759+
results = results.detach().numpy()
760+
targets = targets.detach().numpy()
761+
# print(f'cuda memory:{torch.cuda.memory_allocated(opt.gpu)/ (1024*1024)}')
762+
SIM_matrix = np.zeros(targets.shape[0])
763+
for i in range(targets.shape[0]):
764+
SIM_matrix[i] = SIM(results[i], targets[i])
765+
766+
sim = np.mean(SIM_matrix)
767+
AUC = np.zeros((targets.shape[0], targets.shape[2]))
768+
IOU = np.zeros((targets.shape[0], targets.shape[2]))
769+
IOU_thres = np.linspace(0, 1, 20)
770+
targets = targets >= 0.5
771+
targets = targets.astype(int)
772+
for i in range(AUC.shape[0]):
773+
t_true = targets[i]
774+
p_score = results[i]
775+
776+
if np.sum(t_true) == 0:
777+
AUC[i] = np.nan
778+
IOU[i] = np.nan
779+
else:
780+
auc = roc_auc_score(t_true, p_score)
781+
AUC[i] = auc
782+
783+
p_mask = (p_score > 0.5).astype(int)
784+
temp_iou = []
785+
for thre in IOU_thres:
786+
p_mask = (p_score >= thre).astype(int)
787+
intersect = np.sum(p_mask & t_true)
788+
union = np.sum(p_mask | t_true)
789+
temp_iou.append(1.*intersect/union)
790+
temp_iou = np.array(temp_iou)
791+
aiou = np.mean(temp_iou)
792+
IOU[i] = aiou
793+
794+
AUC = np.nanmean(AUC)
795+
IOU = np.nanmean(IOU)
796+
797+
logger.debug(f'AUC:{AUC} | IOU:{IOU} | SIM:{sim} | MAE:{mean_mae}')
798+
799+
current_IOU = IOU
800+
if(current_IOU > best_IOU):
801+
best_IOU = current_IOU
802+
best_model_path = save_path + '/best_model-{}.pt'.format(sign)
803+
checkpoint = {
804+
'model': model.state_dict(),
805+
'optimizer': optimizer.state_dict(),
806+
'Epoch': epoch
807+
}
808+
torch.save(checkpoint, best_model_path)
809+
logger.debug(f'best model saved at {best_model_path}')
810+
scheduler.step()
811+
logger.debug(f'Best Val IOU:{best_IOU}')
812+
813+
category_metrics, affordance_metrics, overall_metrics = evaluate(model, test_loader, device, 3)
814+
print_metrics_in_table(category_metrics, affordance_metrics, overall_metrics, logger)
815+
```
816+
817+
### 损失函数
818+

0 commit comments

Comments
 (0)