Skip to content

Commit a61c4a0

Browse files
committed
[RISCV][SelectionDAG] Lower shuffles as bitrotates with vror.vi when possible
Given a shuffle mask like <3, 0, 1, 2, 7, 4, 5, 6> for v8i8, we can reinterpret it as a shuffle of v2i32 where the two i32s are bit rotated, and lower it as a vror.vi (if legal with zvbb enabled). We also need to make sure that the larger element type is a valid SEW, hence the tests for zve32x. X86 already did this, so I've extracted the logic for it and put it inside ShuffleVectorSDNode so it could be reused by RISC-V. I originally tried to add this as a generic combine in DAGCombiner.cpp, but it ended up causing worse codegen on X86 and PPC. Reviewed By: reames, pengfei Differential Revision: https://reviews.llvm.org/D157417
1 parent 4a5bcbd commit a61c4a0

File tree

6 files changed

+944
-72
lines changed

6 files changed

+944
-72
lines changed

llvm/include/llvm/IR/Instructions.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2444,6 +2444,21 @@ class ShuffleVectorInst : public Instruction {
24442444
return isInterleaveMask(Mask, Factor, NumInputElts, StartIndexes);
24452445
}
24462446

2447+
/// Checks if the shuffle is a bit rotation of the first operand across
2448+
/// multiple subelements, e.g.:
2449+
///
2450+
/// shuffle <8 x i8> %a, <8 x i8> poison, <8 x i32> <1, 0, 3, 2, 5, 4, 7, 6>
2451+
///
2452+
/// could be expressed as
2453+
///
2454+
/// rotl <4 x i16> %a, 8
2455+
///
2456+
/// If it can be expressed as a rotation, returns the number of subelements to
2457+
/// group by in NumSubElts and the number of bits to rotate left in RotateAmt.
///
/// Mask is the shuffle mask to inspect; negative entries are ignored when
/// matching. Candidate group sizes start at MinSubElts and double while they
/// do not exceed MaxSubElts, and the first size that matches is the one
/// reported. RotateAmt is the matched rotation counted in elements multiplied
/// by EltSizeInBits. Returns true only if some candidate group size matched.
2458+
static bool isBitRotateMask(ArrayRef<int> Mask, unsigned EltSizeInBits,
2459+
unsigned MinSubElts, unsigned MaxSubElts,
2460+
unsigned &NumSubElts, unsigned &RotateAmt);
2461+
24472462
// Methods for support type inquiry through isa, cast, and dyn_cast:
24482463
static bool classof(const Instruction *I) {
24492464
return I->getOpcode() == Instruction::ShuffleVector;

llvm/lib/IR/Instructions.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,6 +2833,45 @@ bool ShuffleVectorInst::isInterleaveMask(
28332833
return true;
28342834
}
28352835

2836+
/// Try to lower a vector shuffle as a bit rotation.
2837+
///
2838+
/// Look for a repeated rotation pattern in each sub group.
2839+
/// Returns an element-wise left bit rotation amount or -1 if failed.
2840+
static int matchShuffleAsBitRotate(ArrayRef<int> Mask, int NumSubElts) {
2841+
int NumElts = Mask.size();
2842+
assert((NumElts % NumSubElts) == 0 && "Illegal shuffle mask");
2843+
2844+
int RotateAmt = -1;
2845+
for (int i = 0; i != NumElts; i += NumSubElts) {
2846+
for (int j = 0; j != NumSubElts; ++j) {
2847+
int M = Mask[i + j];
2848+
if (M < 0)
2849+
continue;
2850+
if (M < i || M >= i + NumSubElts)
2851+
return -1;
2852+
int Offset = (NumSubElts - (M - (i + j))) % NumSubElts;
2853+
if (0 <= RotateAmt && Offset != RotateAmt)
2854+
return -1;
2855+
RotateAmt = Offset;
2856+
}
2857+
}
2858+
return RotateAmt;
2859+
}
2860+
2861+
/// Checks whether Mask rotates the elements inside equally sized groups, so
/// that the shuffle can be expressed as a bit rotation of wider elements.
///
/// Group sizes MinSubElts, 2 * MinSubElts, 4 * MinSubElts, ... are tried in
/// that order while they do not exceed MaxSubElts; the first size that matches
/// wins. On success NumSubElts receives that group size, RotateAmt receives
/// the left rotation in bits (element rotation times EltSizeInBits), and true
/// is returned. On failure false is returned and neither output is written.
///
/// An empty mask, a MinSubElts of zero, or a group size that does not evenly
/// divide the mask length is reported as "no match" rather than being passed
/// to matchShuffleAsBitRotate, which requires an exact tiling.
bool ShuffleVectorInst::isBitRotateMask(
    ArrayRef<int> Mask, unsigned EltSizeInBits, unsigned MinSubElts,
    unsigned MaxSubElts, unsigned &NumSubElts, unsigned &RotateAmt) {
  const unsigned NumElts = Mask.size();
  // Zero would never advance when doubled and would be used as a divisor; an
  // empty mask has nothing to rotate.
  if (NumElts == 0 || MinSubElts == 0)
    return false;

  for (unsigned GroupSize = MinSubElts; GroupSize <= MaxSubElts;
       GroupSize *= 2) {
    // Groups must tile the mask exactly. If GroupSize does not divide NumElts
    // then no multiple of it does either, so later candidates cannot match.
    if (NumElts % GroupSize != 0)
      return false;

    int EltRotateAmt = matchShuffleAsBitRotate(Mask, GroupSize);
    if (EltRotateAmt < 0)
      continue;

    NumSubElts = GroupSize;
    RotateAmt = EltRotateAmt * EltSizeInBits;
    return true;
  }

  return false;
}
2874+
28362875
//===----------------------------------------------------------------------===//
28372876
// InsertValueInst Class
28382877
//===----------------------------------------------------------------------===//

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4260,6 +4260,51 @@ static SDValue lowerBitreverseShuffle(ShuffleVectorSDNode *SVN,
42604260
return Res;
42614261
}
42624262

4263+
// Given a shuffle mask like <3, 0, 1, 2, 7, 4, 5, 6> for v8i8, we can
4264+
// reinterpret it as a shuffle of v2i32 where the two i32s are bit rotated, and
4265+
// lower it as a vror.vi (if legal with zvbb enabled).
4266+
static SDValue lowerVECTOR_SHUFFLEAsRotate(ShuffleVectorSDNode *SVN,
4267+
SelectionDAG &DAG,
4268+
const RISCVSubtarget &Subtarget) {
4269+
SDLoc DL(SVN);
4270+
4271+
EVT VT = SVN->getValueType(0);
4272+
unsigned NumElts = VT.getVectorNumElements();
4273+
unsigned EltSizeInBits = VT.getScalarSizeInBits();
4274+
unsigned NumSubElts, RotateAmt;
4275+
if (!ShuffleVectorInst::isBitRotateMask(SVN->getMask(), EltSizeInBits, 2,
4276+
NumElts, NumSubElts, RotateAmt))
4277+
return SDValue();
4278+
MVT RotateVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits * NumSubElts),
4279+
NumElts / NumSubElts);
4280+
4281+
// We might have a RotateVT that isn't legal, e.g. v4i64 on zve32x.
4282+
if (!Subtarget.getTargetLowering()->isOperationLegalOrCustom(ISD::ROTL,
4283+
RotateVT))
4284+
return SDValue();
4285+
4286+
// If we just create the shift amount with
4287+
//
4288+
// DAG.getConstant(RotateAmt, DL, RotateVT)
4289+
//
4290+
// then for e64 we get a weird bitcasted build_vector on RV32 that we're
4291+
// unable to detect as a splat during pattern matching. So directly lower it
4292+
// to a vmv.v.x which gets matched to vror.vi.
4293+
MVT ContainerVT = getContainerForFixedLengthVector(DAG, RotateVT, Subtarget);
4294+
SDValue VL =
4295+
getDefaultVLOps(RotateVT, ContainerVT, DL, DAG, Subtarget).second;
4296+
SDValue RotateAmtSplat = DAG.getNode(
4297+
RISCVISD::VMV_V_X_VL, DL, ContainerVT, DAG.getUNDEF(ContainerVT),
4298+
DAG.getConstant(RotateAmt, DL, Subtarget.getXLenVT()), VL);
4299+
RotateAmtSplat =
4300+
convertFromScalableVector(RotateVT, RotateAmtSplat, DAG, Subtarget);
4301+
4302+
SDValue Rotate =
4303+
DAG.getNode(ISD::ROTL, DL, RotateVT,
4304+
DAG.getBitcast(RotateVT, SVN->getOperand(0)), RotateAmtSplat);
4305+
return DAG.getBitcast(VT, Rotate);
4306+
}
4307+
42634308
static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
42644309
const RISCVSubtarget &Subtarget) {
42654310
SDValue V1 = Op.getOperand(0);
@@ -4270,6 +4315,11 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
42704315
unsigned NumElts = VT.getVectorNumElements();
42714316
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
42724317

4318+
// Lower to a vror.vi of a larger element type if possible. Do this before we
4319+
// promote i1s to i8s.
4320+
if (SDValue V = lowerVECTOR_SHUFFLEAsRotate(SVN, DAG, Subtarget))
4321+
return V;
4322+
42734323
if (VT.getVectorElementType() == MVT::i1) {
42744324
if (SDValue V = lowerBitreverseShuffle(SVN, DAG, Subtarget))
42754325
return V;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10984,31 +10984,6 @@ static SDValue lowerShuffleAsDecomposedShuffleMerge(
1098410984
return DAG.getVectorShuffle(VT, DL, V1, V2, FinalMask);
1098510985
}
1098610986

10987-
/// Try to lower a vector shuffle as a bit rotation.
10988-
///
10989-
/// Look for a repeated rotation pattern in each sub group.
10990-
/// Returns a ISD::ROTL element rotation amount or -1 if failed.
10991-
static int matchShuffleAsBitRotate(ArrayRef<int> Mask, int NumSubElts) {
10992-
int NumElts = Mask.size();
10993-
assert((NumElts % NumSubElts) == 0 && "Illegal shuffle mask");
10994-
10995-
int RotateAmt = -1;
10996-
for (int i = 0; i != NumElts; i += NumSubElts) {
10997-
for (int j = 0; j != NumSubElts; ++j) {
10998-
int M = Mask[i + j];
10999-
if (M < 0)
11000-
continue;
11001-
if (!isInRange(M, i, i + NumSubElts))
11002-
return -1;
11003-
int Offset = (NumSubElts - (M - (i + j))) % NumSubElts;
11004-
if (0 <= RotateAmt && Offset != RotateAmt)
11005-
return -1;
11006-
RotateAmt = Offset;
11007-
}
11008-
}
11009-
return RotateAmt;
11010-
}
11011-
1101210987
static int matchShuffleAsBitRotate(MVT &RotateVT, int EltSizeInBits,
1101310988
const X86Subtarget &Subtarget,
1101410989
ArrayRef<int> Mask) {
@@ -11018,18 +10993,14 @@ static int matchShuffleAsBitRotate(MVT &RotateVT, int EltSizeInBits,
1101810993
// AVX512 only has vXi32/vXi64 rotates, so limit the rotation sub group size.
1101910994
int MinSubElts = Subtarget.hasAVX512() ? std::max(32 / EltSizeInBits, 2) : 2;
1102010995
int MaxSubElts = 64 / EltSizeInBits;
11021-
for (int NumSubElts = MinSubElts; NumSubElts <= MaxSubElts; NumSubElts *= 2) {
11022-
int RotateAmt = matchShuffleAsBitRotate(Mask, NumSubElts);
11023-
if (RotateAmt < 0)
11024-
continue;
11025-
11026-
int NumElts = Mask.size();
11027-
MVT RotateSVT = MVT::getIntegerVT(EltSizeInBits * NumSubElts);
11028-
RotateVT = MVT::getVectorVT(RotateSVT, NumElts / NumSubElts);
11029-
return RotateAmt * EltSizeInBits;
11030-
}
11031-
11032-
return -1;
10996+
unsigned RotateAmt, NumSubElts;
10997+
if (!ShuffleVectorInst::isBitRotateMask(Mask, EltSizeInBits, MinSubElts,
10998+
MaxSubElts, NumSubElts, RotateAmt))
10999+
return -1;
11000+
unsigned NumElts = Mask.size();
11001+
MVT RotateSVT = MVT::getIntegerVT(EltSizeInBits * NumSubElts);
11002+
RotateVT = MVT::getVectorVT(RotateSVT, NumElts / NumSubElts);
11003+
return RotateAmt;
1103311004
}
1103411005

1103511006
/// Lower shuffle using X86ISD::VROTLI rotations.

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-reverse.ll

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,19 @@ define <1 x i8> @reverse_v1i8(<1 x i8> %a) {
169169
}
170170

171171
define <2 x i8> @reverse_v2i8(<2 x i8> %a) {
172-
; CHECK-LABEL: reverse_v2i8:
173-
; CHECK: # %bb.0:
174-
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
175-
; CHECK-NEXT: vslidedown.vi v9, v8, 1
176-
; CHECK-NEXT: vslideup.vi v9, v8, 1
177-
; CHECK-NEXT: vmv1r.v v8, v9
178-
; CHECK-NEXT: ret
172+
; NO-ZVBB-LABEL: reverse_v2i8:
173+
; NO-ZVBB: # %bb.0:
174+
; NO-ZVBB-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
175+
; NO-ZVBB-NEXT: vslidedown.vi v9, v8, 1
176+
; NO-ZVBB-NEXT: vslideup.vi v9, v8, 1
177+
; NO-ZVBB-NEXT: vmv1r.v v8, v9
178+
; NO-ZVBB-NEXT: ret
179+
;
180+
; ZVBB-LABEL: reverse_v2i8:
181+
; ZVBB: # %bb.0:
182+
; ZVBB-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
183+
; ZVBB-NEXT: vror.vi v8, v8, 8
184+
; ZVBB-NEXT: ret
179185
%res = call <2 x i8> @llvm.experimental.vector.reverse.v2i8(<2 x i8> %a)
180186
ret <2 x i8> %res
181187
}
@@ -258,13 +264,19 @@ define <1 x i16> @reverse_v1i16(<1 x i16> %a) {
258264
}
259265

260266
define <2 x i16> @reverse_v2i16(<2 x i16> %a) {
261-
; CHECK-LABEL: reverse_v2i16:
262-
; CHECK: # %bb.0:
263-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
264-
; CHECK-NEXT: vslidedown.vi v9, v8, 1
265-
; CHECK-NEXT: vslideup.vi v9, v8, 1
266-
; CHECK-NEXT: vmv1r.v v8, v9
267-
; CHECK-NEXT: ret
267+
; NO-ZVBB-LABEL: reverse_v2i16:
268+
; NO-ZVBB: # %bb.0:
269+
; NO-ZVBB-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
270+
; NO-ZVBB-NEXT: vslidedown.vi v9, v8, 1
271+
; NO-ZVBB-NEXT: vslideup.vi v9, v8, 1
272+
; NO-ZVBB-NEXT: vmv1r.v v8, v9
273+
; NO-ZVBB-NEXT: ret
274+
;
275+
; ZVBB-LABEL: reverse_v2i16:
276+
; ZVBB: # %bb.0:
277+
; ZVBB-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
278+
; ZVBB-NEXT: vror.vi v8, v8, 16
279+
; ZVBB-NEXT: ret
268280
%res = call <2 x i16> @llvm.experimental.vector.reverse.v2i16(<2 x i16> %a)
269281
ret <2 x i16> %res
270282
}
@@ -332,13 +344,19 @@ define <1 x i32> @reverse_v1i32(<1 x i32> %a) {
332344
}
333345

334346
define <2 x i32> @reverse_v2i32(<2 x i32> %a) {
335-
; CHECK-LABEL: reverse_v2i32:
336-
; CHECK: # %bb.0:
337-
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
338-
; CHECK-NEXT: vslidedown.vi v9, v8, 1
339-
; CHECK-NEXT: vslideup.vi v9, v8, 1
340-
; CHECK-NEXT: vmv1r.v v8, v9
341-
; CHECK-NEXT: ret
347+
; NO-ZVBB-LABEL: reverse_v2i32:
348+
; NO-ZVBB: # %bb.0:
349+
; NO-ZVBB-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
350+
; NO-ZVBB-NEXT: vslidedown.vi v9, v8, 1
351+
; NO-ZVBB-NEXT: vslideup.vi v9, v8, 1
352+
; NO-ZVBB-NEXT: vmv1r.v v8, v9
353+
; NO-ZVBB-NEXT: ret
354+
;
355+
; ZVBB-LABEL: reverse_v2i32:
356+
; ZVBB: # %bb.0:
357+
; ZVBB-NEXT: vsetivli zero, 1, e64, m1, ta, ma
358+
; ZVBB-NEXT: vror.vi v8, v8, 32
359+
; ZVBB-NEXT: ret
342360
%res = call <2 x i32> @llvm.experimental.vector.reverse.v2i32(<2 x i32> %a)
343361
ret <2 x i32> %res
344362
}
@@ -572,13 +590,19 @@ define <1 x half> @reverse_v1f16(<1 x half> %a) {
572590
}
573591

574592
define <2 x half> @reverse_v2f16(<2 x half> %a) {
575-
; CHECK-LABEL: reverse_v2f16:
576-
; CHECK: # %bb.0:
577-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
578-
; CHECK-NEXT: vslidedown.vi v9, v8, 1
579-
; CHECK-NEXT: vslideup.vi v9, v8, 1
580-
; CHECK-NEXT: vmv1r.v v8, v9
581-
; CHECK-NEXT: ret
593+
; NO-ZVBB-LABEL: reverse_v2f16:
594+
; NO-ZVBB: # %bb.0:
595+
; NO-ZVBB-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
596+
; NO-ZVBB-NEXT: vslidedown.vi v9, v8, 1
597+
; NO-ZVBB-NEXT: vslideup.vi v9, v8, 1
598+
; NO-ZVBB-NEXT: vmv1r.v v8, v9
599+
; NO-ZVBB-NEXT: ret
600+
;
601+
; ZVBB-LABEL: reverse_v2f16:
602+
; ZVBB: # %bb.0:
603+
; ZVBB-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
604+
; ZVBB-NEXT: vror.vi v8, v8, 16
605+
; ZVBB-NEXT: ret
582606
%res = call <2 x half> @llvm.experimental.vector.reverse.v2f16(<2 x half> %a)
583607
ret <2 x half> %res
584608
}
@@ -646,13 +670,19 @@ define <1 x float> @reverse_v1f32(<1 x float> %a) {
646670
}
647671

648672
define <2 x float> @reverse_v2f32(<2 x float> %a) {
649-
; CHECK-LABEL: reverse_v2f32:
650-
; CHECK: # %bb.0:
651-
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
652-
; CHECK-NEXT: vslidedown.vi v9, v8, 1
653-
; CHECK-NEXT: vslideup.vi v9, v8, 1
654-
; CHECK-NEXT: vmv1r.v v8, v9
655-
; CHECK-NEXT: ret
673+
; NO-ZVBB-LABEL: reverse_v2f32:
674+
; NO-ZVBB: # %bb.0:
675+
; NO-ZVBB-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
676+
; NO-ZVBB-NEXT: vslidedown.vi v9, v8, 1
677+
; NO-ZVBB-NEXT: vslideup.vi v9, v8, 1
678+
; NO-ZVBB-NEXT: vmv1r.v v8, v9
679+
; NO-ZVBB-NEXT: ret
680+
;
681+
; ZVBB-LABEL: reverse_v2f32:
682+
; ZVBB: # %bb.0:
683+
; ZVBB-NEXT: vsetivli zero, 1, e64, m1, ta, ma
684+
; ZVBB-NEXT: vror.vi v8, v8, 32
685+
; ZVBB-NEXT: ret
656686
%res = call <2 x float> @llvm.experimental.vector.reverse.v2f32(<2 x float> %a)
657687
ret <2 x float> %res
658688
}

0 commit comments

Comments
 (0)