@@ -411,6 +411,158 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size
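+//
+// For example (illustrative MIR only; register names are placeholders):
+//   %ext:_(<8 x s16>) = G_ZEXT %src:_(<8 x s8>)
+//   %sum:_(s16) = G_VECREDUCE_ADD %ext:_(<8 x s16>)
+// becomes a single G_UADDLV of %src, with the scalar sum read out of lane 0
+// of the ADDLV result.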
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check if the last instruction is an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
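+  // MatchInfo.second records whether the extend is signed; the apply step
+  // uses it to choose between G_SADDLV/G_UADDLV and between G_SEXT/G_ZEXT.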
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = 0;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = 1;
+  else
+    return false;
+
+  // Check if the source register is a valid type
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
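+  // (e.g. a <32 x s8> source is split into two <16 x s8> parts, each of
+  // which goes through its own ADDLV before the partial sums are added).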
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts and put each extracted source through U/SADDLV and
+    // put the values inside a small vec
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg =
+      B.buildConstant(LLT::scalar(64), 0)->getOperand(0).getReg();
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If the number of elements is too small to build an instruction, extend
+    // its size before applying addlv
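+    // (e.g. a leftover <4 x s8> part is widened to <4 x s16> first, since
+    // there is no ADDLV encoding that takes only four i8 lanes).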
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              ->getOperand(0)
+              .getReg();
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double the
+    // source's scalar size
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg = B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]})
+                            ->getOperand(0)
+                            .getReg();
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
+    // v2i64 register.
+    //   i16 and i32 results use v4i32 registers
+    //   i64 results use v2i64 registers
+    // Therefore we have to extract/truncate the value to the right type
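+    // (e.g. an i16 partial sum is read out of the v4i32 ADDLV result as an
+    // s32 lane and then truncated down to s16).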
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                ->getOperand(0)
+                                .getReg();
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg})->getOperand(0).getReg();
+    }
+  }
+
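+  // Sum the partial reductions from each split part into a single scalar.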
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 ->getOperand(0)
+                 .getReg();
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I])
+                   ->getOperand(0)
+                   .getReg();
+    }
+  } else {
+    outReg = WorkingRegisters[0];
+  }
+
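+  // (e.g. a <16 x s8> source reduced into an s64 destination yields an s16
+  // partial sum here, which must still be zero- or sign-extended to s64).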
+  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+    // Handle the scalar value if the DstTy's scalar size is more than double
+    // the source's scalar type
+    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                        : TargetOpcode::G_ZEXT,
+                 {DstReg}, {outReg});
+  } else {
+    B.buildCopy(DstReg, outReg);
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if