Commit 49745ce

[RISCV] use vwadd.vx for extended splat.
1 parent 684f27d commit 49745ce
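Summary of the change, as reflected in the diff below: the RISC-V widening-op combine now treats a plain ISD::SPLAT_VECTOR the same way it already treats RISCVISD::VMV_V_X_VL. If the splatted scalar is known to fit in half of the element width, the splat counts as a sign- or zero-extension candidate, so an add of an extended vector and such a splat can select vwadd.vx/vwaddu.vx instead of materializing the wide splat first. The splat handling that used to live inline in fillUpExtensionSupport is hoisted into a shared fillUpExtensionSupportForSplat helper.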

2 files changed: +214, -39 lines

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 59 additions & 39 deletions
@@ -13597,6 +13597,7 @@ struct NodeExtensionHelper {
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
+    case ISD::SPLAT_VECTOR:
       return OrigOperand.getOperand(0);
     default:
       return OrigOperand;
@@ -13605,7 +13606,8 @@ struct NodeExtensionHelper {

   /// Check if this instance represents a splat.
   bool isSplat() const {
-    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+    return (OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL) ||
+           (OrigOperand.getOpcode() == ISD::SPLAT_VECTOR);
   }

   /// Get the extended opcode.
@@ -13649,6 +13651,8 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    case ISD::SPLAT_VECTOR:
+      return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
     case RISCVISD::VMV_V_X_VL:
      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                         DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
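The new SPLAT_VECTOR arm rebuilds the operand as a splat of the same scalar at the narrow type via DAG.getSplat. A minimal plain-C++ sketch (arrays standing in for vectors, not the LLVM API) of the identity this relies on: when the scalar fits in the narrow element type, extending every lane of the narrow splat reproduces the wide splat.

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  const int32_t Scalar = -7; // fits in 16 bits, so narrowing is lossless
  std::array<int16_t, 4> NarrowSplat;
  NarrowSplat.fill(static_cast<int16_t>(Scalar)); // splat at the narrow type
  for (int16_t Lane : NarrowSplat)                // widening each lane gives
    assert(static_cast<int32_t>(Lane) == Scalar); // back the wide splat
  return 0;
}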
@@ -13781,6 +13785,57 @@ struct NodeExtensionHelper {
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }

+  void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+    unsigned Opc = OrigOperand.getOpcode();
+    MVT VT = OrigOperand.getSimpleValueType();
+
+    assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
+           "Unexpected Opcode");
+
+    if (Opc == ISD::SPLAT_VECTOR && !VT.isVector())
+      return;
+
+    // The passthru must be undef for tail agnostic.
+    if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
+      return;
+
+    // Get the scalar value.
+    SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
+                                          : OrigOperand.getOperand(1);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening opcode by splatting to smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return;
+
+    unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+    // If the narrow type cannot be expressed with a legal VMV,
+    // this is not a valid candidate.
+    if (NarrowSize < 8)
+      return;
+
+    if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+      SupportsSExt = true;
+
+    if (DAG.MaskedValueIsZero(Op,
+                              APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+      SupportsZExt = true;
+
+    EnforceOneUse = false;
+    CheckMask = Opc == ISD::SPLAT_VECTOR;
+
+    if (Opc == ISD::SPLAT_VECTOR)
+      std::tie(Mask, VL) =
+          getDefaultScalableVLOps(VT, SDLoc(Root), DAG, Subtarget);
+    else
+      VL = OrigOperand.getOperand(2);
+  }
+
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
   void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
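A self-contained model (plain C++, not the LLVM API) of the two queries in the helper above, specialized to ScalarBits == EltBits == 64 and NarrowSize == 32. For a concrete value, ComputeMaxSignificantBits(Op) <= 32 is equivalent to the value round-tripping through truncation to 32 bits and sign extension back, and MaskedValueIsZero on bits [32, 64) to round-tripping through truncation and zero extension; the real analyses are conservative and may fail to prove either for non-constant SDValues.

#include <cassert>
#include <cstdint>

// Value survives trunc-to-i32 + sign-extend, i.e. <= 32 significant bits.
bool supportsSExt(int64_t V) {
  return V == static_cast<int64_t>(static_cast<int32_t>(V));
}

// Bits [32, 64) are all zero, i.e. value survives trunc-to-i32 + zero-extend.
bool supportsZExt(int64_t V) {
  return (static_cast<uint64_t>(V) >> 32) == 0;
}

int main() {
  assert(supportsSExt(-5) && !supportsZExt(-5));                // vwadd.vx only
  assert(!supportsSExt(1LL << 31) && supportsZExt(1LL << 31));  // vwaddu.vx only
  assert(supportsSExt(42) && supportsZExt(42));                 // either works
  assert(!supportsSExt(1LL << 40) && !supportsZExt(1LL << 40)); // neither
  return 0;
}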
@@ -13826,45 +13881,10 @@ struct NodeExtensionHelper {
       Mask = OrigOperand.getOperand(1);
       VL = OrigOperand.getOperand(2);
       break;
-    case RISCVISD::VMV_V_X_VL: {
-      // Historically, we didn't care about splat values not disappearing during
-      // combines.
-      EnforceOneUse = false;
-      CheckMask = false;
-      VL = OrigOperand.getOperand(2);
-
-      // The operand is a splat of a scalar.
-
-      // The pasthru must be undef for tail agnostic.
-      if (!OrigOperand.getOperand(0).isUndef())
-        break;
-
-      // Get the scalar value.
-      SDValue Op = OrigOperand.getOperand(1);
-
-      // See if we have enough sign bits or zero bits in the scalar to use a
-      // widening opcode by splatting to smaller element size.
-      MVT VT = Root->getSimpleValueType(0);
-      unsigned EltBits = VT.getScalarSizeInBits();
-      unsigned ScalarBits = Op.getValueSizeInBits();
-      // Make sure we're getting all element bits from the scalar register.
-      // FIXME: Support implicit sign extension of vmv.v.x?
-      if (ScalarBits < EltBits)
-        break;
-
-      unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-      // If the narrow type cannot be expressed with a legal VMV,
-      // this is not a valid candidate.
-      if (NarrowSize < 8)
-        break;
-
-      if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
-        SupportsSExt = true;
-      if (DAG.MaskedValueIsZero(Op,
-                                APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
-        SupportsZExt = true;
+    case ISD::SPLAT_VECTOR:
+    case RISCVISD::VMV_V_X_VL:
+      fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
       break;
-    }
     default:
       break;
     }

llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll

Lines changed: 155 additions & 0 deletions
@@ -1466,3 +1466,158 @@ define <vscale x 2 x i32> @vwadd_wv_disjoint_or(<vscale x 2 x i32> %x.i32, <vsca
   %or = or disjoint <vscale x 2 x i32> %x.i32, %y.i32
   ret <vscale x 2 x i32> %or
 }
+
+define <vscale x 8 x i64> @vwadd_vx_splat_zext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw zero, 12(sp)
+; RV32-NEXT:    sw a0, 8(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV32-NEXT:    vlse64.v v16, (a0), zero
+; RV32-NEXT:    vwaddu.wv v16, v16, v8
+; RV32-NEXT:    vmv8r.v v8, v16
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwaddu.vx v16, v8, a0
+; RV64-NEXT:    vmv8r.v v8, v16
+; RV64-NEXT:    ret
+  %sb = zext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = zext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+  %ve = add <vscale x 8 x i64> %vc, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_zext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_zext_i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srli a0, a0, 16
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT:    vmv.v.x v8, a0
+; RV32-NEXT:    vadd.vi v8, v8, 1, v0.t
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_zext_i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srli a0, a0, 48
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV64-NEXT:    vmv.v.x v8, a0
+; RV64-NEXT:    vadd.vi v8, v8, 1, v0.t
+; RV64-NEXT:    ret
+  %sb = zext i16 %b to i32
+  %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = zext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+  %ve = add <vscale x 8 x i32> %vc, %splat
+  ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_zext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_zext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw zero, 12(sp)
+; RV32-NEXT:    sw a0, 8(sp)
+; RV32-NEXT:    addi a0, sp, 8
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vlse64.v v16, (a0), zero
+; RV32-NEXT:    vadd.vv v8, v8, v16
+; RV32-NEXT:    addi sp, sp, 16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_wx_splat_zext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    slli a0, a0, 32
+; RV64-NEXT:    srli a0, a0, 32
+; RV64-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = zext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %ve = add <vscale x 8 x i64> %va, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vmv.v.x v16, a0
+; RV32-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; RV32-NEXT:    vwadd.wv v16, v16, v8
+; RV32-NEXT:    vmv8r.v v8, v16
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, ma
+; RV64-NEXT:    vwadd.vx v16, v8, a0
+; RV64-NEXT:    vmv8r.v v8, v16
+; RV64-NEXT:    ret
+  %sb = sext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
+  %ve = add <vscale x 8 x i64> %vc, %splat
+  ret <vscale x 8 x i64> %ve
+}
+
+define <vscale x 8 x i32> @vwadd_vx_splat_sext_i1(<vscale x 8 x i1> %va, i16 %b) {
+; RV32-LABEL: vwadd_vx_splat_sext_i1:
+; RV32:       # %bb.0:
+; RV32-NEXT:    slli a0, a0, 16
+; RV32-NEXT:    srai a0, a0, 16
+; RV32-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV32-NEXT:    vmv.v.x v8, a0
+; RV32-NEXT:    li a0, 1
+; RV32-NEXT:    vsub.vx v8, v8, a0, v0.t
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_vx_splat_sext_i1:
+; RV64:       # %bb.0:
+; RV64-NEXT:    slli a0, a0, 48
+; RV64-NEXT:    srai a0, a0, 48
+; RV64-NEXT:    vsetvli a1, zero, e32, m4, ta, mu
+; RV64-NEXT:    vmv.v.x v8, a0
+; RV64-NEXT:    li a0, 1
+; RV64-NEXT:    vsub.vx v8, v8, a0, v0.t
+; RV64-NEXT:    ret
+  %sb = sext i16 %b to i32
+  %head = insertelement <vscale x 8 x i32> poison, i32 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i32> %head, <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
+  %vc = sext <vscale x 8 x i1> %va to <vscale x 8 x i32>
+  %ve = add <vscale x 8 x i32> %vc, %splat
+  ret <vscale x 8 x i32> %ve
+}
+
+define <vscale x 8 x i64> @vwadd_wx_splat_sext(<vscale x 8 x i64> %va, i32 %b) {
+; RV32-LABEL: vwadd_wx_splat_sext:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV32-NEXT:    vadd.vx v8, v8, a0
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: vwadd_wx_splat_sext:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a0, a0
+; RV64-NEXT:    vsetvli a1, zero, e64, m8, ta, ma
+; RV64-NEXT:    vadd.vx v8, v8, a0
+; RV64-NEXT:    ret
+  %sb = sext i32 %b to i64
+  %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
+  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+  %ve = add <vscale x 8 x i64> %va, %splat
+  ret <vscale x 8 x i64> %ve
+}
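Taken together, the new tests pin down three behaviors visible above: on RV64 a splat of a zero- or sign-extended i32 scalar now feeds vwaddu.vx/vwadd.vx directly (on RV32 the 64-bit scalar still has to be assembled via the stack or a wide vmv.v.x first); the vwadd_wx variants, where the splat is already at the wide type, lower to plain vadd.vx/vadd.vv; and the i1 inputs keep their masked vadd.vi/vsub.vx lowering, since no single widening step applies to an i1 extension.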
