@@ -438,6 +438,122 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
438
438
MI.eraseFromParent ();
439
439
}
440
440
441
+ // Match mul({z/s}ext , {z/s}ext) => {u/s}mull
442
+ bool matchExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
443
+ GISelValueTracking *KB,
444
+ std::tuple<bool , Register, Register> &MatchInfo) {
445
+ // Get the instructions that defined the source operand
446
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
447
+ MachineInstr *I1 = getDefIgnoringCopies (MI.getOperand (1 ).getReg (), MRI);
448
+ MachineInstr *I2 = getDefIgnoringCopies (MI.getOperand (2 ).getReg (), MRI);
449
+ unsigned I1Opc = I1->getOpcode ();
450
+ unsigned I2Opc = I2->getOpcode ();
451
+ unsigned EltSize = DstTy.getScalarSizeInBits ();
452
+
453
+ if (!DstTy.isVector () || I1->getNumOperands () < 2 || I2->getNumOperands () < 2 )
454
+ return false ;
455
+
456
+ auto IsAtLeastDoubleExtend = [&](Register R) {
457
+ LLT Ty = MRI.getType (R);
458
+ return EltSize >= Ty.getScalarSizeInBits () * 2 ;
459
+ };
460
+
461
+ // If the source operands were EXTENDED before, then {U/S}MULL can be used
462
+ bool IsZExt1 =
463
+ I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
464
+ bool IsZExt2 =
465
+ I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
466
+ if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
467
+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
468
+ get<0 >(MatchInfo) = true ;
469
+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
470
+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
471
+ return true ;
472
+ }
473
+
474
+ bool IsSExt1 =
475
+ I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
476
+ bool IsSExt2 =
477
+ I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
478
+ if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend (I1->getOperand (1 ).getReg ()) &&
479
+ IsAtLeastDoubleExtend (I2->getOperand (1 ).getReg ())) {
480
+ get<0 >(MatchInfo) = false ;
481
+ get<1 >(MatchInfo) = I1->getOperand (1 ).getReg ();
482
+ get<2 >(MatchInfo) = I2->getOperand (1 ).getReg ();
483
+ return true ;
484
+ }
485
+
486
+ // Select UMULL if we can replace the other operand with an extend.
487
+ APInt Mask = APInt::getHighBitsSet (EltSize, EltSize / 2 );
488
+ if (KB && (IsZExt1 || IsZExt2) &&
489
+ IsAtLeastDoubleExtend (IsZExt1 ? I1->getOperand (1 ).getReg ()
490
+ : I2->getOperand (1 ).getReg ())) {
491
+ Register ZExtOp =
492
+ IsZExt1 ? MI.getOperand (2 ).getReg () : MI.getOperand (1 ).getReg ();
493
+ if (KB->maskedValueIsZero (ZExtOp, Mask)) {
494
+ get<0 >(MatchInfo) = true ;
495
+ get<1 >(MatchInfo) = IsZExt1 ? I1->getOperand (1 ).getReg () : ZExtOp;
496
+ get<2 >(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand (1 ).getReg ();
497
+ return true ;
498
+ }
499
+ } else if (KB && DstTy == LLT::fixed_vector (2 , 64 ) &&
500
+ KB->maskedValueIsZero (MI.getOperand (1 ).getReg (), Mask) &&
501
+ KB->maskedValueIsZero (MI.getOperand (2 ).getReg (), Mask)) {
502
+ get<0 >(MatchInfo) = true ;
503
+ get<1 >(MatchInfo) = MI.getOperand (1 ).getReg ();
504
+ get<2 >(MatchInfo) = MI.getOperand (2 ).getReg ();
505
+ return true ;
506
+ }
507
+
508
+ if (KB && (IsSExt1 || IsSExt2) &&
509
+ IsAtLeastDoubleExtend (IsSExt1 ? I1->getOperand (1 ).getReg ()
510
+ : I2->getOperand (1 ).getReg ())) {
511
+ Register SExtOp =
512
+ IsSExt1 ? MI.getOperand (2 ).getReg () : MI.getOperand (1 ).getReg ();
513
+ if (KB->computeNumSignBits (SExtOp) > EltSize / 2 ) {
514
+ get<0 >(MatchInfo) = false ;
515
+ get<1 >(MatchInfo) = IsSExt1 ? I1->getOperand (1 ).getReg () : SExtOp;
516
+ get<2 >(MatchInfo) = IsSExt1 ? SExtOp : I2->getOperand (1 ).getReg ();
517
+ return true ;
518
+ }
519
+ } else if (KB && DstTy == LLT::fixed_vector (2 , 64 ) &&
520
+ KB->computeNumSignBits (MI.getOperand (1 ).getReg ()) > EltSize / 2 &&
521
+ KB->computeNumSignBits (MI.getOperand (2 ).getReg ()) > EltSize / 2 ) {
522
+ get<0 >(MatchInfo) = false ;
523
+ get<1 >(MatchInfo) = MI.getOperand (1 ).getReg ();
524
+ get<2 >(MatchInfo) = MI.getOperand (2 ).getReg ();
525
+ return true ;
526
+ }
527
+
528
+ return false ;
529
+ }
530
+
531
+ void applyExtMulToMULL (MachineInstr &MI, MachineRegisterInfo &MRI,
532
+ MachineIRBuilder &B, GISelChangeObserver &Observer,
533
+ std::tuple<bool , Register, Register> &MatchInfo) {
534
+ assert (MI.getOpcode () == TargetOpcode::G_MUL &&
535
+ " Expected a G_MUL instruction" );
536
+
537
+ // Get the instructions that defined the source operand
538
+ LLT DstTy = MRI.getType (MI.getOperand (0 ).getReg ());
539
+ bool IsZExt = get<0 >(MatchInfo);
540
+ Register Src1Reg = get<1 >(MatchInfo);
541
+ Register Src2Reg = get<2 >(MatchInfo);
542
+ LLT Src1Ty = MRI.getType (Src1Reg);
543
+ LLT Src2Ty = MRI.getType (Src2Reg);
544
+ LLT HalfDstTy = DstTy.changeElementSize (DstTy.getScalarSizeInBits () / 2 );
545
+ unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
546
+
547
+ if (Src1Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
548
+ Src1Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src1Reg}).getReg (0 );
549
+ if (Src2Ty.getScalarSizeInBits () * 2 != DstTy.getScalarSizeInBits ())
550
+ Src2Reg = B.buildExtOrTrunc (ExtOpc, {HalfDstTy}, {Src2Reg}).getReg (0 );
551
+
552
+ B.buildInstr (IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
553
+ {MI.getOperand (0 ).getReg ()}, {Src1Reg, Src2Reg});
554
+ MI.eraseFromParent ();
555
+ }
556
+
441
557
class AArch64PostLegalizerCombinerImpl : public Combiner {
442
558
protected:
443
559
const CombinerHelper Helper;
0 commit comments