@@ -564,34 +564,33 @@ bool matchPushAddSubExt(
564
564
assert (MI.getOpcode () == TargetOpcode::G_ADD ||
565
565
MI.getOpcode () == TargetOpcode::G_SUB &&
566
566
" Expected a G_ADD or G_SUB instruction\n " );
567
- MachineInstr *ExtMI1 = MRI.getVRegDef (MI.getOperand (1 ).getReg ());
568
- MachineInstr *ExtMI2 = MRI.getVRegDef (MI.getOperand (2 ).getReg ());
569
567
570
- LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
568
+ // Deal with vector types only
569
+ get<1 >(matchinfo) = MI.getOperand (0 ).getReg ();
570
+ LLT DstTy = MRI.getType (get<1 >(matchinfo));
571
571
if (!DstTy.isVector ())
572
572
return false ;
573
573
574
- // Check the source came from G_{S/Z}EXT instructions
575
- if (ExtMI1->getOpcode () != ExtMI2->getOpcode () ||
576
- (ExtMI1->getOpcode () != TargetOpcode::G_SEXT &&
577
- ExtMI1->getOpcode () != TargetOpcode::G_ZEXT))
578
- return false ;
579
-
580
- if (!MRI.hasOneUse (ExtMI1->getOperand (0 ).getReg ()) ||
581
- !MRI.hasOneUse (ExtMI2->getOperand (0 ).getReg ()))
574
+ // Matching instruction pattern
575
+ Register Src1Reg = MI.getOperand (1 ).getReg ();
576
+ Register Src2Reg = MI.getOperand (2 ).getReg ();
577
+ bool ZExt =
578
+ mi_match (Src1Reg, MRI,
579
+ m_OneNonDBGUse (m_GZExt (m_Reg (get<2 >(matchinfo))))) &&
580
+ mi_match (Src2Reg, MRI, m_OneNonDBGUse (m_GZExt (m_Reg (get<3 >(matchinfo)))));
581
+ bool SExt =
582
+ mi_match (Src1Reg, MRI,
583
+ m_OneNonDBGUse (m_GSExt (m_Reg (get<2 >(matchinfo))))) &&
584
+ mi_match (Src2Reg, MRI, m_OneNonDBGUse (m_GSExt (m_Reg (get<3 >(matchinfo)))));
585
+ if (!ZExt && !SExt)
582
586
return false ;
587
+ get<0 >(matchinfo) = SExt;
583
588
584
589
// Return true if G_{S|Z}EXT instruction is more than 2* source
585
590
Register ExtDstReg = MI.getOperand (1 ).getReg ();
586
- get<0 >(matchinfo) = ExtMI1->getOpcode () == TargetOpcode::G_SEXT;
587
- get<1 >(matchinfo) = MI.getOperand (0 ).getReg ();
588
- get<2 >(matchinfo) = ExtMI1->getOperand (1 ).getReg ();
589
- get<3 >(matchinfo) = ExtMI2->getOperand (1 ).getReg ();
590
-
591
591
LLT ExtDstTy = MRI.getType (ExtDstReg);
592
592
LLT Ext1SrcTy = MRI.getType (get<2 >(matchinfo));
593
593
LLT Ext2SrcTy = MRI.getType (get<3 >(matchinfo));
594
-
595
594
if (((Ext1SrcTy.getScalarSizeInBits () == 8 &&
596
595
ExtDstTy.getScalarSizeInBits () == 32 ) ||
597
596
((Ext1SrcTy.getScalarSizeInBits () == 8 ||
0 commit comments