@@ -519,6 +519,121 @@ SDValue LoongArchTargetLowering::lowerBITREVERSE(SDValue Op,
519
519
}
520
520
}
521
521
522
+ // / Attempts to match a shuffle mask against the VBSLL, VBSRL, VSLLI and VSRLI
523
+ // / instruction.
524
+ // The funciton matches elements form one of the input vector shuffled to the
525
+ // left or right with zeroable elements 'shifted in'. It handles both the
526
+ // strictly bit-wise element shifts and the byte shfit across an entire 128-bit
527
+ // lane.
528
+ // This is mainly copy from X86.
529
+ static int matchShuffleAsShift (MVT &ShiftVT, unsigned &Opcode,
530
+ unsigned ScalarSizeInBits, ArrayRef<int > Mask,
531
+ int MaskOffset, const APInt &Zeroable) {
532
+ int Size = Mask.size ();
533
+ unsigned SizeInBits = Size * ScalarSizeInBits;
534
+
535
+ auto CheckZeros = [&](int Shift, int Scale, bool Left) {
536
+ for (int i = 0 ; i < Size; i += Scale)
537
+ for (int j = 0 ; j < Shift; ++j)
538
+ if (!Zeroable[i + j + (Left ? 0 : (Scale - Shift))])
539
+ return false ;
540
+
541
+ return true ;
542
+ };
543
+
544
+ auto isSequentialOrUndefInRange = [&](unsigned Pos, unsigned Size, int Low,
545
+ int Step = 1 ) {
546
+ for (unsigned i = Pos, e = Pos + Size; i != e; ++i, Low += Step)
547
+ if (!(Mask[i] == -1 || Mask[i] == Low))
548
+ return false ;
549
+ return true ;
550
+ };
551
+
552
+ auto MatchShift = [&](int Shift, int Scale, bool Left) {
553
+ for (int i = 0 ; i != Size; i += Scale) {
554
+ unsigned Pos = Left ? i + Shift : i;
555
+ unsigned Low = Left ? i : i + Shift;
556
+ unsigned Len = Scale - Shift;
557
+ if (!isSequentialOrUndefInRange (Pos, Len, Low + MaskOffset))
558
+ return -1 ;
559
+ }
560
+
561
+ int ShiftEltBits = ScalarSizeInBits * Scale;
562
+ bool ByteShift = ShiftEltBits > 64 ;
563
+ Opcode = Left ? (ByteShift ? LoongArchISD::VBSLL : LoongArchISD::VSLLI)
564
+ : (ByteShift ? LoongArchISD::VBSRL : LoongArchISD::VSRLI);
565
+ int ShiftAmt = Shift * ScalarSizeInBits / (ByteShift ? 8 : 1 );
566
+
567
+ // Normalize the scale for byte shifts to still produce an i64 element
568
+ // type.
569
+ Scale = ByteShift ? Scale / 2 : Scale;
570
+
571
+ // We need to round trip through the appropriate type for the shift.
572
+ MVT ShiftSVT = MVT::getIntegerVT (ScalarSizeInBits * Scale);
573
+ ShiftVT = ByteShift ? MVT::getVectorVT (MVT::i8 , SizeInBits / 8 )
574
+ : MVT::getVectorVT (ShiftSVT, Size / Scale);
575
+ return (int )ShiftAmt;
576
+ };
577
+
578
+ unsigned MaxWidth = 128 ;
579
+ for (int Scale = 2 ; Scale * ScalarSizeInBits <= MaxWidth; Scale *= 2 )
580
+ for (int Shift = 1 ; Shift != Scale; ++Shift)
581
+ for (bool Left : {true , false })
582
+ if (CheckZeros (Shift, Scale, Left)) {
583
+ int ShiftAmt = MatchShift (Shift, Scale, Left);
584
+ if (0 < ShiftAmt)
585
+ return ShiftAmt;
586
+ }
587
+
588
+ // no match
589
+ return -1 ;
590
+ }
591
+
592
+ // / Lower VECTOR_SHUFFLE as shift (if possible).
593
+ // /
594
+ // / For example:
595
+ // / %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
596
+ // / <4 x i32> <i32 4, i32 0, i32 1, i32 2>
597
+ // / is lowered to:
598
+ // / (VBSLL_V $v0, $v0, 4)
599
+ // /
600
+ // / %2 = shufflevector <4 x i32> %0, <4 x i32> zeroinitializer,
601
+ // / <4 x i32> <i32 4, i32 0, i32 4, i32 2>
602
+ // / is lowered to:
603
+ // / (VSLLI_D $v0, $v0, 32)
604
+ static SDValue lowerVECTOR_SHUFFLEAsShift (const SDLoc &DL, ArrayRef<int > Mask,
605
+ MVT VT, SDValue V1, SDValue V2,
606
+ SelectionDAG &DAG,
607
+ const APInt &Zeroable) {
608
+ int Size = Mask.size ();
609
+ assert (Size == (int )VT.getVectorNumElements () && " Unexpected mask size" );
610
+
611
+ MVT ShiftVT;
612
+ SDValue V = V1;
613
+ unsigned Opcode;
614
+
615
+ // Try to match shuffle against V1 shift.
616
+ int ShiftAmt = matchShuffleAsShift (ShiftVT, Opcode, VT.getScalarSizeInBits (),
617
+ Mask, 0 , Zeroable);
618
+
619
+ // If V1 failed, try to match shuffle against V2 shift.
620
+ if (ShiftAmt < 0 ) {
621
+ ShiftAmt = matchShuffleAsShift (ShiftVT, Opcode, VT.getScalarSizeInBits (),
622
+ Mask, Size, Zeroable);
623
+ V = V2;
624
+ }
625
+
626
+ if (ShiftAmt < 0 )
627
+ return SDValue ();
628
+
629
+ assert (DAG.getTargetLoweringInfo ().isTypeLegal (ShiftVT) &&
630
+ " Illegal integer vector type" );
631
+ V = DAG.getBitcast (ShiftVT, V);
632
+ V = DAG.getNode (Opcode, DL, ShiftVT, V,
633
+ DAG.getConstant (ShiftAmt, DL, MVT::i64 ));
634
+ return DAG.getBitcast (VT, V);
635
+ }
636
+
522
637
// / Determine whether a range fits a regular pattern of values.
523
638
// / This function accounts for the possibility of jumping over the End iterator.
524
639
template <typename ValType>
@@ -587,14 +702,12 @@ static void computeZeroableShuffleElements(ArrayRef<int> Mask, SDValue V1,
587
702
static SDValue lowerVECTOR_SHUFFLEAsZeroOrAnyExtend (const SDLoc &DL,
588
703
ArrayRef<int > Mask, MVT VT,
589
704
SDValue V1, SDValue V2,
590
- SelectionDAG &DAG) {
705
+ SelectionDAG &DAG,
706
+ const APInt &Zeroable) {
591
707
int Bits = VT.getSizeInBits ();
592
708
int EltBits = VT.getScalarSizeInBits ();
593
709
int NumElements = VT.getVectorNumElements ();
594
710
595
- APInt KnownUndef, KnownZero;
596
- computeZeroableShuffleElements (Mask, V1, V2, KnownUndef, KnownZero);
597
- APInt Zeroable = KnownUndef | KnownZero;
598
711
if (Zeroable.isAllOnes ())
599
712
return DAG.getConstant (0 , DL, VT);
600
713
@@ -1056,6 +1169,10 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
1056
1169
" Unexpected mask size for shuffle!" );
1057
1170
assert (Mask.size () % 2 == 0 && " Expected even mask size." );
1058
1171
1172
+ APInt KnownUndef, KnownZero;
1173
+ computeZeroableShuffleElements (Mask, V1, V2, KnownUndef, KnownZero);
1174
+ APInt Zeroable = KnownUndef | KnownZero;
1175
+
1059
1176
SDValue Result;
1060
1177
// TODO: Add more comparison patterns.
1061
1178
if (V2.isUndef ()) {
@@ -1083,12 +1200,14 @@ static SDValue lower128BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
1083
1200
return Result;
1084
1201
if ((Result = lowerVECTOR_SHUFFLE_VPICKOD (DL, Mask, VT, V1, V2, DAG)))
1085
1202
return Result;
1203
+ if ((Result = lowerVECTOR_SHUFFLEAsZeroOrAnyExtend (DL, Mask, VT, V1, V2, DAG,
1204
+ Zeroable)))
1205
+ return Result;
1086
1206
if ((Result =
1087
- lowerVECTOR_SHUFFLEAsZeroOrAnyExtend (DL, Mask, VT, V1, V2, DAG)))
1207
+ lowerVECTOR_SHUFFLEAsShift (DL, Mask, VT, V1, V2, DAG, Zeroable )))
1088
1208
return Result;
1089
1209
if ((Result = lowerVECTOR_SHUFFLE_VSHUF (DL, Mask, VT, V1, V2, DAG)))
1090
1210
return Result;
1091
-
1092
1211
return SDValue ();
1093
1212
}
1094
1213
@@ -4997,6 +5116,10 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
4997
5116
NODE_NAME_CASE (VANY_NONZERO)
4998
5117
NODE_NAME_CASE (FRECIPE)
4999
5118
NODE_NAME_CASE (FRSQRTE)
5119
+ NODE_NAME_CASE (VSLLI)
5120
+ NODE_NAME_CASE (VSRLI)
5121
+ NODE_NAME_CASE (VBSLL)
5122
+ NODE_NAME_CASE (VBSRL)
5000
5123
}
5001
5124
#undef NODE_NAME_CASE
5002
5125
return nullptr ;
0 commit comments