@@ -674,3 +674,145 @@ _3daffordance = torch.sigmoid(_3daffordance)
674
674
- 在 token 维度求和(或平均池化),融合多个 token 的关注信息;
675
675
- 使用 sigmoid 得到最终的掩码,形状 `(B, N)`;
676
676
- 每个点的值 ∈ [0, 1],表示其属于目标功能区域的概率;
677
+
678
+ ## 训练
679
+
680
+ 训练部分的核心代码实现如下:
681
+
682
+ ``` python
683
def main(opt, dict):
    """Train PointRefer, validate after every epoch, then report test metrics.

    Args:
        opt: Options object. This function reads ``opt.n_groups``,
            ``opt.tlr`` (learning rate for the text-encoder parameters) and
            ``opt.decay_rate`` (Adam weight decay).
        dict: Configuration mapping. Keys read here: ``'emb_dim'``,
            ``'proj_dim'``, ``'num_heads'``, ``'N_raw'``,
            ``'num_affordance'``, ``'lr'`` and ``'Epoch'``. The parameter
            name shadows the builtin but is kept for caller compatibility.

    Note:
        The function also relies on names that are not defined in this
        snippet and must be provided by the surrounding module:
        ``batch_size``, ``start_epoch``, ``save_path``, ``sign``, ``logger``
        and ``device``.
    """
    # 1. Build the train / validation / test datasets and their loaders.
    train_dataset = AffordQ('train')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8,
                              shuffle=True, drop_last=True)
    val_dataset = AffordQ('val')
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=8,
                            shuffle=False)
    test_dataset = AffordQ('test')
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8,
                             shuffle=False)

    # 2. Build the model.
    model = get_PointRefer(emb_dim=dict['emb_dim'],
                           proj_dim=dict['proj_dim'],
                           num_heads=dict['num_heads'],
                           N_raw=dict['N_raw'],
                           num_affordance=dict['num_affordance'],
                           n_groups=opt.n_groups)

    # 3. Loss, optimizer and learning-rate scheduler.
    # Only the heat-map loss is optimized; the classification
    # (cross-entropy) term is disabled in this version of the script.
    criterion_hm = HM_Loss()
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    param_dicts = [
        # Everything except the text encoder uses the optimizer's default lr.
        {"params": [p for name, p in trainable if "text_encoder" not in name]},
        # The text encoder gets its own learning rate, opt.tlr.
        {"params": [p for name, p in trainable if "text_encoder" in name],
         "lr": opt.tlr},
    ]
    optimizer = torch.optim.Adam(params=param_dicts, lr=dict['lr'],
                                 betas=(0.9, 0.999), eps=1e-8,
                                 weight_decay=opt.decay_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Must exist before the first comparison below; without it the first
    # validation pass raises UnboundLocalError.
    best_IOU = 0.0
    iou_thresholds = np.linspace(0, 1, 20)

    for epoch in range(start_epoch + 1, dict['Epoch']):
        # ---- Training ----
        model = model.train()
        for point, _, gt_mask, question, _ in train_loader:
            optimizer.zero_grad()
            # 4. Forward pass.
            _3d = model(question, point)
            # 5. Loss and parameter update.
            loss_hm = criterion_hm(_3d, gt_mask)
            loss_hm.backward()
            optimizer.step()

        # ---- Evaluation (runs after every epoch) ----
        with torch.no_grad():
            model = model.eval()
            total_MAE = 0
            total_point = 0
            pred_chunks = []
            target_chunks = []
            for point, _, label, question, _ in val_loader:
                point, label = point.float(), label.float()

                _3d = model(question, point)
                mae, point_nums = evaluating(_3d, label)
                total_point += point_nums
                total_MAE += mae.item()
                # Collect per-batch outputs on the CPU with a trailing
                # singleton axis; concatenating afterwards avoids a buffer
                # with a hard-coded number of points per cloud.
                pred_chunks.append(_3d.detach().float().cpu().unsqueeze(-1))
                target_chunks.append(label.detach().float().cpu().unsqueeze(-1))

            mean_mae = total_MAE / total_point
            results = torch.cat(pred_chunks).numpy()
            targets = torch.cat(target_chunks).numpy()
            num_samples = targets.shape[0]

            # SIM is computed against the raw (non-binarized) targets.
            SIM_matrix = np.zeros(num_samples)
            for i, (pred, gt) in enumerate(zip(results, targets)):
                SIM_matrix[i] = SIM(pred, gt)
            sim = np.mean(SIM_matrix)

            # AUC and threshold-averaged IoU use targets binarized at 0.5.
            binary_targets = (targets >= 0.5).astype(int)
            auc_per_sample = np.full(num_samples, np.nan)
            iou_per_sample = np.full(num_samples, np.nan)
            for i, (t_true, p_score) in enumerate(zip(binary_targets, results)):
                if np.sum(t_true) == 0:
                    # No positive point: both metrics stay NaN and are
                    # skipped by nanmean below.
                    continue
                auc_per_sample[i] = roc_auc_score(t_true, p_score)

                ious = []
                for thre in iou_thresholds:
                    p_mask = (p_score >= thre).astype(int)
                    intersect = np.sum(p_mask & t_true)
                    union = np.sum(p_mask | t_true)
                    ious.append(1. * intersect / union)
                iou_per_sample[i] = np.mean(ious)

            AUC = np.nanmean(auc_per_sample)
            IOU = np.nanmean(iou_per_sample)

            logger.debug(f'AUC:{AUC} | IOU:{IOU} | SIM:{sim} | MAE:{mean_mae}')

            # Keep the checkpoint with the best validation IoU so far.
            if IOU > best_IOU:
                best_IOU = IOU
                best_model_path = save_path + '/best_model-{}.pt'.format(sign)
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'Epoch': epoch,
                }
                torch.save(checkpoint, best_model_path)
                logger.debug(f'best model saved at {best_model_path}')
        scheduler.step()

    logger.debug(f'Best Val IOU:{best_IOU}')

    # NOTE(review): the test evaluation uses the model as it stands after the
    # last epoch, not the saved best checkpoint -- confirm this is intended.
    category_metrics, affordance_metrics, overall_metrics = evaluate(model, test_loader, device, 3)
    print_metrics_in_table(category_metrics, affordance_metrics, overall_metrics, logger)
815
+ ```
816
+
817
+ ### 损失函数
818
+
0 commit comments