Skip to content

[RISCV] Remove RISCVISD::VNSRL_VL and adjust deinterleave lowering to match #118391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 21 additions & 44 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
@@ -4618,51 +4618,31 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
// VT is the type of the vector to return, <[vscale x ]n x ty>
// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
bool EvenElts,
const RISCVSubtarget &Subtarget,
SelectionDAG &DAG) {
// The result is a vector of type <m x n x ty>
MVT ContainerVT = VT;
// Convert fixed vectors to scalable if needed
if (ContainerVT.isFixedLengthVector()) {
assert(Src.getSimpleValueType().isFixedLengthVector());
ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);

// The source is a vector of type <m x n*2 x ty> (For the single source
// case, the high half is undefined)
MVT SrcContainerVT =
MVT::getVectorVT(ContainerVT.getVectorElementType(),
ContainerVT.getVectorElementCount() * 2);
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
bool EvenElts, SelectionDAG &DAG) {
// The result is a vector of type <m x n x ty>. The source is a vector of
// type <m x n*2 x ty> (For the single source case, the high half is undef)
if (Src.getValueType() == VT) {
EVT WideVT = VT.getDoubleNumVectorElementsVT();
Src = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT, DAG.getUNDEF(WideVT),
Src, DAG.getVectorIdxConstant(0, DL));
}

auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

// Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
// This also converts FP to int.
unsigned EltBits = ContainerVT.getScalarSizeInBits();
MVT WideSrcContainerVT = MVT::getVectorVT(
MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
Src = DAG.getBitcast(WideSrcContainerVT, Src);
unsigned EltBits = VT.getScalarSizeInBits();
MVT WideSrcVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
VT.getVectorElementCount());
Src = DAG.getBitcast(WideSrcVT, Src);

// The integer version of the container type.
MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
MVT IntVT = VT.changeVectorElementTypeToInteger();

// If we want even elements, then the shift amount is 0. Otherwise, shift by
// the original element size.
unsigned Shift = EvenElts ? 0 : EltBits;
SDValue SplatShift = DAG.getNode(
RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
SDValue Res =
DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
DAG.getUNDEF(IntContainerVT), TrueMask, VL);
// Cast back to FP if needed.
Res = DAG.getBitcast(ContainerVT, Res);

if (VT.isFixedLengthVector())
Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
return Res;
SDValue Res = DAG.getNode(ISD::SRL, DL, WideSrcVT, Src,
DAG.getConstant(Shift, DL, WideSrcVT));
Res = DAG.getNode(ISD::TRUNCATE, DL, IntVT, Res);
return DAG.getBitcast(VT, Res);
}

// Lower the following shuffle to vslidedown.
@@ -5356,7 +5336,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
// vnsrl to deinterleave.
if (SDValue Src =
isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, Subtarget, DAG);
return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, DAG);

if (SDValue V =
lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -6258,7 +6238,7 @@ static bool hasPassthruOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
128 &&
127 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -6284,7 +6264,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
128 &&
127 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -10763,10 +10743,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
// We can deinterleave through vnsrl.wi if the element type is smaller than
// ELEN
if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
SDValue Even =
getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
SDValue Odd =
getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
SDValue Even = getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, DAG);
SDValue Odd = getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, DAG);
return DAG.getMergeValues({Even, Odd}, DL);
}

@@ -20494,7 +20472,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWMACC_VL)
NODE_NAME_CASE(VWMACCU_VL)
NODE_NAME_CASE(VWMACCSU_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VMERGE_VL)
NODE_NAME_CASE(VMAND_VL)
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
@@ -369,10 +369,6 @@ enum NodeType : unsigned {
VWMACCU_VL,
VWMACCSU_VL,

// Narrowing logical shift right.
// Operands are (source, shift, passthru, mask, vl)
VNSRL_VL,

// Vector compare producing a mask. Fourth operand is input mask. Fifth
// operand is VL.
SETCC_VL,
Expand Down
36 changes: 0 additions & 36 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
@@ -459,17 +459,6 @@ def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNP
def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>;

def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisInt<1>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisOpSmallerThanOp<0, 1>,
SDTCisSameAs<0, 2>,
SDTCisSameAs<0, 3>,
SDTCisSameNumEltsAs<0, 4>,
SDTCVecEltisVT<4, i1>,
SDTCisVT<5, XLenVT>]>;
def riscv_vnsrl_vl : SDNode<"RISCVISD::VNSRL_VL", SDT_RISCVVNIntBinOp_VL>;

def SDT_RISCVVWIntBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisSameAs<0, 1>,
SDTCisInt<2>,
@@ -885,29 +874,6 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
}
}

multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name> {
foreach VtiToWti = AllWidenableIntVectors in {
defvar vti = VtiToWti.Vti;
defvar wti = VtiToWti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def : VPatBinaryVL_V<vop, instruction_name, "WV",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
vti.RegClass>;
def : VPatBinaryVL_XI<vop, instruction_name, "WX",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
SplatPat, GPR>;
def : VPatBinaryVL_XI<vop, instruction_name, "WI",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
!cast<ComplexPattern>(SplatPat#_#uimm5),
uimm5>;
}
}
}

class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
@@ -2166,8 +2132,6 @@ defm : VPatNarrowShiftSplatExt_WX<riscv_srl_vl, riscv_zext_vl_oneuse, "PseudoVNS
defm : VPatNarrowShiftVL_WV<riscv_srl_vl, "PseudoVNSRL">;
defm : VPatNarrowShiftVL_WV<riscv_sra_vl, "PseudoVNSRA">;

defm : VPatBinaryNVL_WV_WX_WI<riscv_vnsrl_vl, "PseudoVNSRL">;

foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
Expand Down
29 changes: 14 additions & 15 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
Original file line number Diff line number Diff line change
@@ -97,41 +97,40 @@ define <4 x i32> @v4i32_v8i32(<8 x i32>) {
define <4 x i32> @v4i32_v16i32(<16 x i32>) {
; RV32-LABEL: v4i32_v16i32:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 8, e32, m4, ta, ma
; RV32-NEXT: vslidedown.vi v16, v8, 8
; RV32-NEXT: vmv4r.v v20, v8
; RV32-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; RV32-NEXT: vmv.v.i v8, 1
; RV32-NEXT: vmv2r.v v22, v12
; RV32-NEXT: vmv.v.i v10, 6
; RV32-NEXT: vmv.v.i v12, 1
; RV32-NEXT: vmv.v.i v14, 6
; RV32-NEXT: li a0, 32
; RV32-NEXT: vmv.v.i v0, 10
; RV32-NEXT: vsetivli zero, 2, e16, m1, tu, ma
; RV32-NEXT: vslideup.vi v10, v8, 1
; RV32-NEXT: vslideup.vi v14, v12, 1
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32-NEXT: vnsrl.wx v12, v8, a0
; RV32-NEXT: vsetivli zero, 8, e32, m4, ta, ma
; RV32-NEXT: vslidedown.vi v8, v8, 8
; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV32-NEXT: vnsrl.wx v8, v20, a0
; RV32-NEXT: vrgatherei16.vv v8, v16, v10, v0.t
; RV32-NEXT: vrgatherei16.vv v12, v8, v14, v0.t
; RV32-NEXT: vmv1r.v v8, v12
; RV32-NEXT: ret
;
; RV64-LABEL: v4i32_v16i32:
; RV64: # %bb.0:
; RV64-NEXT: vsetivli zero, 8, e32, m4, ta, ma
; RV64-NEXT: vslidedown.vi v16, v8, 8
; RV64-NEXT: vmv4r.v v20, v8
; RV64-NEXT: li a0, 32
; RV64-NEXT: vmv2r.v v22, v12
; RV64-NEXT: vsetivli zero, 1, e8, mf8, ta, ma
; RV64-NEXT: vmv.v.i v0, 10
; RV64-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV64-NEXT: vnsrl.wx v8, v20, a0
; RV64-NEXT: vnsrl.wx v12, v8, a0
; RV64-NEXT: vsetivli zero, 8, e32, m4, ta, ma
; RV64-NEXT: vslidedown.vi v8, v8, 8
; RV64-NEXT: li a0, 3
; RV64-NEXT: slli a0, a0, 33
; RV64-NEXT: addi a0, a0, 1
; RV64-NEXT: slli a0, a0, 16
; RV64-NEXT: vsetivli zero, 2, e64, m1, ta, ma
; RV64-NEXT: vmv.v.x v10, a0
; RV64-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; RV64-NEXT: vrgatherei16.vv v8, v16, v10, v0.t
; RV64-NEXT: vrgatherei16.vv v12, v8, v10, v0.t
; RV64-NEXT: vmv1r.v v8, v12
; RV64-NEXT: ret
%2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 1, i32 9, i32 5, i32 14>
ret <4 x i32> %2
Expand Down
Loading