Skip to content

Commit 469caa3

Browse files
authored
[RISCV] Use vwadd.vx for splat vector with extension (#87249)
This patch allows `combineBinOp_VLToVWBinOp_VL` to handle patterns like `(splat_vector (sext op))` or `(splat_vector (zext op))`. Then we can use `vwadd.vx` and `vwadd.w` for such a case. ### Source code ``` define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) { %sb = sext i32 %b to i64 %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0 %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64> %ve = add <vscale x 8 x i64> %vc, %splat ret <vscale x 8 x i64> %ve } ``` ### Before this patch [Compiler Explorer](https://godbolt.org/z/sq191PsT4) ``` vwadd_vx_splat_sext: sext.w a0, a0 vsetvli a1, zero, e64, m8, ta, ma vmv.v.x v16, a0 vsetvli zero, zero, e32, m4, ta, ma vwadd.wv v16, v16, v8 vmv8r.v v8, v16 ret ``` ### After this patch ``` vwadd_vx_splat_sext vsetvli a1, zero, e32, m4, ta, ma vwadd.vx v16, v8, a0 vmv8r.v v8, v16 ret ```
1 parent 313a33b commit 469caa3

File tree

5 files changed

+569
-235
lines changed

5 files changed

+569
-235
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13597,7 +13597,8 @@ struct NodeExtensionHelper {
1359713597

1359813598
/// Check if this instance represents a splat.
1359913599
bool isSplat() const {
13600-
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
13600+
return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
13601+
OrigOperand.getOpcode() == ISD::SPLAT_VECTOR;
1360113602
}
1360213603

1360313604
/// Get the extended opcode.
@@ -13641,6 +13642,8 @@ struct NodeExtensionHelper {
1364113642
case RISCVISD::VZEXT_VL:
1364213643
case RISCVISD::FP_EXTEND_VL:
1364313644
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
13645+
case ISD::SPLAT_VECTOR:
13646+
return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
1364413647
case RISCVISD::VMV_V_X_VL:
1364513648
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
1364613649
DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
@@ -13776,6 +13779,47 @@ struct NodeExtensionHelper {
1377613779
/// Check if this node needs to be fully folded or extended for all users.
1377713780
bool needToPromoteOtherUsers() const { return EnforceOneUse; }
1377813781

13782+
void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
13783+
const RISCVSubtarget &Subtarget) {
13784+
unsigned Opc = OrigOperand.getOpcode();
13785+
MVT VT = OrigOperand.getSimpleValueType();
13786+
13787+
assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
13788+
"Unexpected Opcode");
13789+
13790+
// The pasthru must be undef for tail agnostic.
13791+
if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
13792+
return;
13793+
13794+
// Get the scalar value.
13795+
SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
13796+
: OrigOperand.getOperand(1);
13797+
13798+
// See if we have enough sign bits or zero bits in the scalar to use a
13799+
// widening opcode by splatting to smaller element size.
13800+
unsigned EltBits = VT.getScalarSizeInBits();
13801+
unsigned ScalarBits = Op.getValueSizeInBits();
13802+
// Make sure we're getting all element bits from the scalar register.
13803+
// FIXME: Support implicit sign extension of vmv.v.x?
13804+
if (ScalarBits < EltBits)
13805+
return;
13806+
13807+
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13808+
// If the narrow type cannot be expressed with a legal VMV,
13809+
// this is not a valid candidate.
13810+
if (NarrowSize < 8)
13811+
return;
13812+
13813+
if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
13814+
SupportsSExt = true;
13815+
13816+
if (DAG.MaskedValueIsZero(Op,
13817+
APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
13818+
SupportsZExt = true;
13819+
13820+
EnforceOneUse = false;
13821+
}
13822+
1377913823
/// Helper method to set the various fields of this struct based on the
1378013824
/// type of \p Root.
1378113825
void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
@@ -13814,43 +13858,10 @@ struct NodeExtensionHelper {
1381413858
case RISCVISD::FP_EXTEND_VL:
1381513859
SupportsFPExt = true;
1381613860
break;
13817-
case RISCVISD::VMV_V_X_VL: {
13818-
// Historically, we didn't care about splat values not disappearing during
13819-
// combines.
13820-
EnforceOneUse = false;
13821-
13822-
// The operand is a splat of a scalar.
13823-
13824-
// The pasthru must be undef for tail agnostic.
13825-
if (!OrigOperand.getOperand(0).isUndef())
13826-
break;
13827-
13828-
// Get the scalar value.
13829-
SDValue Op = OrigOperand.getOperand(1);
13830-
13831-
// See if we have enough sign bits or zero bits in the scalar to use a
13832-
// widening opcode by splatting to smaller element size.
13833-
MVT VT = Root->getSimpleValueType(0);
13834-
unsigned EltBits = VT.getScalarSizeInBits();
13835-
unsigned ScalarBits = Op.getValueSizeInBits();
13836-
// Make sure we're getting all element bits from the scalar register.
13837-
// FIXME: Support implicit sign extension of vmv.v.x?
13838-
if (ScalarBits < EltBits)
13839-
break;
13840-
13841-
unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
13842-
// If the narrow type cannot be expressed with a legal VMV,
13843-
// this is not a valid candidate.
13844-
if (NarrowSize < 8)
13845-
break;
13846-
13847-
if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
13848-
SupportsSExt = true;
13849-
if (DAG.MaskedValueIsZero(Op,
13850-
APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
13851-
SupportsZExt = true;
13861+
case ISD::SPLAT_VECTOR:
13862+
case RISCVISD::VMV_V_X_VL:
13863+
fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
1385213864
break;
13853-
}
1385413865
default:
1385513866
break;
1385613867
}

0 commit comments

Comments
 (0)