@@ -410,6 +410,158 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
  MI.eraseFromParent();
}

+ // Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+ // Ensure that the type coming from the extend instruction is the right size
+ bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            std::pair<Register, bool> &MatchInfo) {
+   assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+          "Expected G_VECREDUCE_ADD Opcode");
+
+   // Check if the last instruction is an extend
+   MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+   auto ExtOpc = ExtMI->getOpcode();
+
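+   // MatchInfo.second records the signedness of the extend: 0 selects
+   // G_UADDLV and 1 selects G_SADDLV in the apply step below.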
+   if (ExtOpc == TargetOpcode::G_ZEXT)
+     std::get<1>(MatchInfo) = 0;
+   else if (ExtOpc == TargetOpcode::G_SEXT)
+     std::get<1>(MatchInfo) = 1;
+   else
+     return false;
+
+   // Check if the source register is a valid type
+   Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+   LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
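+   // {U/S}ADDLV operates on full 64/128-bit source vectors; these element
+   // count checks ensure the source can be split into such chunks by the
+   // apply step below.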
+   if ((DstTy.getScalarSizeInBits() == 16 &&
+        ExtSrcTy.getNumElements() % 8 == 0) ||
+       (DstTy.getScalarSizeInBits() == 32 &&
+        ExtSrcTy.getNumElements() % 4 == 0) ||
+       (DstTy.getScalarSizeInBits() == 64 &&
+        ExtSrcTy.getNumElements() % 4 == 0)) {
+     std::get<0>(MatchInfo) = ExtSrcReg;
+     return true;
+   }
+   return false;
+ }
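+
+ // For illustration (a sketch, not part of this patch): with an unsigned
+ // extend, generic MIR such as
+ //   %ext:_(<8 x s32>) = G_ZEXT %x:_(<8 x s8>)
+ //   %sum:_(s32) = G_VECREDUCE_ADD %ext:_(<8 x s32>)
+ // is rewritten by the apply step into a G_UADDLV of %x, whose lane-0 result
+ // is then extracted, truncated to s16, and zero-extended to s32.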
+
+ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                            MachineIRBuilder &B, GISelChangeObserver &Observer,
+                            std::pair<Register, bool> &MatchInfo) {
+   assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+          "Expected G_VECREDUCE_ADD Opcode");
+
+   unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+   Register SrcReg = std::get<0>(MatchInfo);
+   Register DstReg = MI.getOperand(0).getReg();
+   LLT SrcTy = MRI.getType(SrcReg);
+   LLT DstTy = MRI.getType(DstReg);
+
+   // If SrcTy has more elements than expected, split them into multiple
+   // instructions and sum the results
+   LLT MainTy;
+   SmallVector<Register, 1> WorkingRegisters;
+   unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+   unsigned SrcNumElem = SrcTy.getNumElements();
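+   // A single 128-bit register holds 16 x i8, 8 x i16, or 4 x i32, so a
+   // source wider than that is split into 128-bit chunks plus leftovers.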
+   if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+       (SrcScalSize == 16 && SrcNumElem > 8) ||
+       (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+     LLT LeftoverTy;
+     SmallVector<Register, 4> LeftoverRegs;
+     if (SrcScalSize == 8)
+       MainTy = LLT::fixed_vector(16, 8);
+     else if (SrcScalSize == 16)
+       MainTy = LLT::fixed_vector(8, 16);
+     else if (SrcScalSize == 32)
+       MainTy = LLT::fixed_vector(4, 32);
+     else
+       llvm_unreachable("Source's Scalar Size not supported");
+
+     // Extract the parts, then put each extracted source through {U/S}ADDLV,
+     // collecting the values in a small vector
+     extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                  LeftoverRegs, B, MRI);
+     for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+       WorkingRegisters.push_back(LeftoverRegs[I]);
+     }
+   } else {
+     WorkingRegisters.push_back(SrcReg);
+     MainTy = SrcTy;
+   }
+
+   unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+   LLT MidScalarLLT = LLT::scalar(MidScalarSize);
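+   // Constant 0, used below as the lane index when extracting the scalar
+   // result out of the vector that {U/S}ADDLV defines.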
+   Register zeroReg =
+       B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
+   for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+     // If the number of elements is too small to build an instruction, extend
+     // its size before applying addlv
+     LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+     if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+         (WorkingRegTy.getNumElements() == 4)) {
+       WorkingRegisters[I] =
+           B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                               : TargetOpcode::G_ZEXT,
+                        {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+               ->getOperand(0)
+               .getReg();
+     }
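+     // (A v4i8 chunk has no ADDLV encoding of its own, hence the widening to
+     // v4i16 above.)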
+
+     // Generate the {U/S}ADDLV instruction, whose output is always double
+     // the Src's scalar size
+     LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                       : LLT::fixed_vector(2, 64);
+     Register addlvReg = B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]})
+                             ->getOperand(0)
+                             .getReg();
+
+     // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32
+     // or v2i64 register:
+     //   i16 and i32 results use a v4i32 register
+     //   i64 results use a v2i64 register
+     // Therefore we have to extract/truncate the value to the right type
+     if (MidScalarSize == 32 || MidScalarSize == 64) {
+       WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                          {MidScalarLLT}, {addlvReg, zeroReg})
+                                 ->getOperand(0)
+                                 .getReg();
+     } else {
+       Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                          {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                 ->getOperand(0)
+                                 .getReg();
+       WorkingRegisters[I] =
+           B.buildTrunc({MidScalarLLT}, {extractReg})->getOperand(0).getReg();
+     }
+   }
+
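+   // Sum the partial reductions (one per chunk) into a single scalar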
+   Register outReg;
+   if (WorkingRegisters.size() > 1) {
+     outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                  ->getOperand(0)
+                  .getReg();
+     for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+       outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I])
+                    ->getOperand(0)
+                    .getReg();
+     }
+   } else {
+     outReg = WorkingRegisters[0];
+   }
+
+   if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+     // Extend the scalar value if the DstTy's scalar size is more than double
+     // the Src's scalar type
+     B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                         : TargetOpcode::G_ZEXT,
+                  {DstReg}, {outReg});
+   } else {
+     B.buildCopy(DstReg, outReg);
+   }
+
+   MI.eraseFromParent();
+ }
+
bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                        CombinerHelper &Helper, GISelChangeObserver &Observer) {
  // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if