@@ -410,6 +410,150 @@ void applyExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Matches {U/S}ADDV(ext(x)) => {U/S}ADDLV(x)
+// Ensure that the type coming from the extend instruction is the right size
+bool matchExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  // Check if the last instruction is an extend
+  MachineInstr *ExtMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  auto ExtOpc = ExtMI->getOpcode();
+
+  if (ExtOpc == TargetOpcode::G_ZEXT)
+    std::get<1>(MatchInfo) = 0;
+  else if (ExtOpc == TargetOpcode::G_SEXT)
+    std::get<1>(MatchInfo) = 1;
+  else
+    return false;
+
+  // Check if the source register is a valid type
+  Register ExtSrcReg = ExtMI->getOperand(1).getReg();
+  LLT ExtSrcTy = MRI.getType(ExtSrcReg);
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  if ((DstTy.getScalarSizeInBits() == 16 &&
+       ExtSrcTy.getNumElements() % 8 == 0 && ExtSrcTy.getNumElements() < 256) ||
+      (DstTy.getScalarSizeInBits() == 32 &&
+       ExtSrcTy.getNumElements() % 4 == 0) ||
+      (DstTy.getScalarSizeInBits() == 64 &&
+       ExtSrcTy.getNumElements() % 4 == 0)) {
+    std::get<0>(MatchInfo) = ExtSrcReg;
+    return true;
+  }
+  return false;
+}
+
+void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
+                           MachineIRBuilder &B, GISelChangeObserver &Observer,
+                           std::pair<Register, bool> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_VECREDUCE_ADD &&
+         "Expected G_VECREDUCE_ADD Opcode");
+
+  unsigned Opc = std::get<1>(MatchInfo) ? AArch64::G_SADDLV : AArch64::G_UADDLV;
+  Register SrcReg = std::get<0>(MatchInfo);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  LLT DstTy = MRI.getType(DstReg);
+
+  // If SrcTy has more elements than expected, split them into multiple
+  // instructions and sum the results
+  LLT MainTy;
+  SmallVector<Register, 1> WorkingRegisters;
+  unsigned SrcScalSize = SrcTy.getScalarSizeInBits();
+  unsigned SrcNumElem = SrcTy.getNumElements();
+  if ((SrcScalSize == 8 && SrcNumElem > 16) ||
+      (SrcScalSize == 16 && SrcNumElem > 8) ||
+      (SrcScalSize == 32 && SrcNumElem > 4)) {
+
+    LLT LeftoverTy;
+    SmallVector<Register, 4> LeftoverRegs;
+    if (SrcScalSize == 8)
+      MainTy = LLT::fixed_vector(16, 8);
+    else if (SrcScalSize == 16)
+      MainTy = LLT::fixed_vector(8, 16);
+    else if (SrcScalSize == 32)
+      MainTy = LLT::fixed_vector(4, 32);
+    else
+      llvm_unreachable("Source's Scalar Size not supported");
+
+    // Extract the parts, put each extracted source through U/SADDLV, and
+    // collect the values in a small vector
+    extractParts(SrcReg, SrcTy, MainTy, LeftoverTy, WorkingRegisters,
+                 LeftoverRegs, B, MRI);
+    for (unsigned I = 0; I < LeftoverRegs.size(); I++) {
+      WorkingRegisters.push_back(LeftoverRegs[I]);
+    }
+  } else {
+    WorkingRegisters.push_back(SrcReg);
+    MainTy = SrcTy;
+  }
+
+  unsigned MidScalarSize = MainTy.getScalarSizeInBits() * 2;
+  LLT MidScalarLLT = LLT::scalar(MidScalarSize);
+  Register zeroReg = B.buildConstant(LLT::scalar(64), 0).getReg(0);
+  for (unsigned I = 0; I < WorkingRegisters.size(); I++) {
+    // If the number of elements is too small to build an instruction, extend
+    // its size before applying addlv
+    LLT WorkingRegTy = MRI.getType(WorkingRegisters[I]);
+    if ((WorkingRegTy.getScalarSizeInBits() == 8) &&
+        (WorkingRegTy.getNumElements() == 4)) {
+      WorkingRegisters[I] =
+          B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                              : TargetOpcode::G_ZEXT,
+                       {LLT::fixed_vector(4, 16)}, {WorkingRegisters[I]})
+              .getReg(0);
+    }
+
+    // Generate the {U/S}ADDLV instruction, whose output is always double the
+    // Src's scalar size
+    LLT addlvTy = MidScalarSize <= 32 ? LLT::fixed_vector(4, 32)
+                                      : LLT::fixed_vector(2, 64);
+    Register addlvReg =
+        B.buildInstr(Opc, {addlvTy}, {WorkingRegisters[I]}).getReg(0);
+
+    // The output from {U/S}ADDLV gets placed in the lowest lane of a v4i32 or
+    // v2i64 register.
+    //   i16, i32 results use v4i32 registers
+    //   i64 results use v2i64 registers
+    // Therefore we have to extract/truncate the value to the right type
+    if (MidScalarSize == 32 || MidScalarSize == 64) {
+      WorkingRegisters[I] = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {MidScalarLLT}, {addlvReg, zeroReg})
+                                .getReg(0);
+    } else {
+      Register extractReg = B.buildInstr(AArch64::G_EXTRACT_VECTOR_ELT,
+                                         {LLT::scalar(32)}, {addlvReg, zeroReg})
+                                .getReg(0);
+      WorkingRegisters[I] =
+          B.buildTrunc({MidScalarLLT}, {extractReg}).getReg(0);
+    }
+  }
+
+  Register outReg;
+  if (WorkingRegisters.size() > 1) {
+    outReg = B.buildAdd(MidScalarLLT, WorkingRegisters[0], WorkingRegisters[1])
+                 .getReg(0);
+    for (unsigned I = 2; I < WorkingRegisters.size(); I++) {
+      outReg = B.buildAdd(MidScalarLLT, outReg, WorkingRegisters[I]).getReg(0);
+    }
+  } else {
+    outReg = WorkingRegisters[0];
+  }
+
+  if (DstTy.getScalarSizeInBits() > MidScalarSize) {
+    // Handle the scalar value if the DstTy's scalar size is more than double
+    // the Src's scalar type
+    B.buildInstr(std::get<1>(MatchInfo) ? TargetOpcode::G_SEXT
+                                        : TargetOpcode::G_ZEXT,
+                 {DstReg}, {outReg});
+  } else {
+    B.buildCopy(DstReg, outReg);
+  }
+
+  MI.eraseFromParent();
+}
+
 bool tryToSimplifyUADDO(MachineInstr &MI, MachineIRBuilder &B,
                         CombinerHelper &Helper, GISelChangeObserver &Observer) {
   // Try simplify G_UADDO with 8 or 16 bit operands to wide G_ADD and TBNZ if
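
The hunk itself carries no surrounding commentary, so two standalone sketches follow. First, the legality check in matchExtUaddvToUaddlv reduces to a predicate on the reduction's result width and the pre-extend element count. This minimal model restates that condition so it can be exercised in isolation; it is plain C++, not LLVM code, and matchesUaddlvPattern is a made-up name:

#include <cstdio>

// Hypothetical stand-in for the type check in matchExtUaddvToUaddlv: accept
// the same (result scalar width, source element count) pairs as the combine.
static bool matchesUaddlvPattern(unsigned DstScalarBits, unsigned SrcNumElems) {
  return (DstScalarBits == 16 && SrcNumElems % 8 == 0 && SrcNumElems < 256) ||
         (DstScalarBits == 32 && SrcNumElems % 4 == 0) ||
         (DstScalarBits == 64 && SrcNumElems % 4 == 0);
}

int main() {
  std::printf("%d\n", matchesUaddlvPattern(16, 16)); // 1: v16i8 summed to i16
  std::printf("%d\n", matchesUaddlvPattern(32, 8));  // 1: v8i16 summed to i32
  std::printf("%d\n", matchesUaddlvPattern(32, 6));  // 0: 6 lanes, not 4-aligned
  return 0;
}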
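Second, the rewrite in applyExtUaddvToUaddlv relies on split-then-sum being value-preserving: because every lane is widened before it is accumulated, reducing each chunk with a widening add and then adding the partial sums gives the same result as extending the whole vector first. A runnable model of that reasoning, where uaddlvModel is a hypothetical scalar stand-in for G_UADDLV rather than an LLVM API:

#include <cstdint>
#include <cstdio>
#include <vector>

// Models one unsigned G_UADDLV: a widening add across all lanes of a chunk.
static uint64_t uaddlvModel(const uint8_t *Lanes, size_t NumLanes) {
  uint64_t Sum = 0;
  for (size_t I = 0; I != NumLanes; ++I)
    Sum += Lanes[I]; // each i8 lane is widened before the accumulate
  return Sum;
}

int main() {
  // A v32i8 source is too wide for one UADDLV, so the combine splits it into
  // two v16i8 chunks, reduces each, and adds the partial sums (the buildAdd
  // loop over WorkingRegisters in the diff above).
  std::vector<uint8_t> Src(32, 200); // lane values that would overflow i8
  uint64_t Ref = 0;
  for (uint8_t V : Src)
    Ref += V; // reference: extend every lane, then reduce
  uint64_t Split =
      uaddlvModel(Src.data(), 16) + uaddlvModel(Src.data() + 16, 16);
  std::printf("ref=%llu split=%llu\n", (unsigned long long)Ref,
              (unsigned long long)Split);
  return Ref == Split ? 0 : 1;
}

Both sums print 6400, a value representable only after widening; the same argument covers the signed (G_SADDLV) path with sign extension in place of zero extension.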