Skip to content

Commit ce8f094

Browse files
committed
[RISCV] Add patterns for vnsrl.vx where shift amount is truncated
Similar to D155698, where the shift amount is extended, this patch extends the ComplexPattern to handle the case where the shift amount has been truncated. Truncations are custom lowered to truncate_vector_vl, and in cases like i64 -> i16 the value is truncated by one power of two at a time, so we need to unravel the nested layers of truncates. The pattern can also be reused for Zvbb's vwsll.vx in an upcoming patch.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D155928
1 parent 7c652fe commit ce8f094

File tree

4 files changed

+32
-24
lines changed

4 files changed

+32
-24
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3017,13 +3017,29 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
30173017
return true;
30183018
}
30193019

3020-
bool RISCVDAGToDAGISel::selectExtOneUseVSplat(SDValue N, SDValue &SplatVal) {
3021-
if (N->getOpcode() == ISD::SIGN_EXTEND ||
3022-
N->getOpcode() == ISD::ZERO_EXTEND) {
3023-
if (!N.hasOneUse())
3020+
bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
3021+
// Truncates are custom lowered during legalization.
3022+
auto IsTrunc = [this](SDValue N) {
3023+
if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
3024+
return false;
3025+
SDValue VL;
3026+
selectVLOp(N->getOperand(2), VL);
3027+
// Any vmset_vl is ok, since any bits past VL are undefined and we can
3028+
// assume they are set.
3029+
return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
3030+
isa<ConstantSDNode>(VL) &&
3031+
cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
3032+
};
3033+
3034+
// We can have multiple nested truncates, so unravel them all if needed.
3035+
while (N->getOpcode() == ISD::SIGN_EXTEND ||
3036+
N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
3037+
if (!N.hasOneUse() ||
3038+
N.getValueType().getSizeInBits().getKnownMinValue() < 8)
30243039
return false;
30253040
N = N->getOperand(0);
30263041
}
3042+
30273043
return selectVSplat(N, SplatVal);
30283044
}
30293045

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
134134
}
135135
bool selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal);
136136
bool selectVSplatSimm5Plus1NonZero(SDValue N, SDValue &SplatVal);
137-
bool selectExtOneUseVSplat(SDValue N, SDValue &SplatVal);
137+
// Matches the splat of a value which can be extended or truncated, such that
138+
// only the bottom 8 bits are preserved.
139+
bool selectLow8BitsVSplat(SDValue N, SDValue &SplatVal);
138140
bool selectFPImm(SDValue N, SDValue &Imm);
139141

140142
bool selectRVVSimm5(SDValue N, unsigned Width, SDValue &Imm);

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -577,8 +577,10 @@ def SplatPat_simm5_plus1
577577
def SplatPat_simm5_plus1_nonzero
578578
: ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero", [], [], 3>;
579579

580-
def ext_oneuse_SplatPat
581-
: ComplexPattern<vAny, 1, "selectExtOneUseVSplat", [], [], 2>;
580+
// Selects extends or truncates of splats where we only care about the lowest 8
581+
// bits of each element.
582+
def Low8BitsSplatPat
583+
: ComplexPattern<vAny, 1, "selectLow8BitsVSplat", [], [], 2>;
582584

583585
def SelectFPImm : ComplexPattern<fAny, 1, "selectFPImm", [], [], 1>;
584586

@@ -1453,7 +1455,7 @@ multiclass VPatBinaryVL_WV_WX_WI<SDNode op, string instruction_name> {
14531455
(vti.Vector
14541456
(riscv_trunc_vector_vl
14551457
(op (wti.Vector wti.RegClass:$rs2),
1456-
(wti.Vector (ext_oneuse_SplatPat (XLenVT GPR:$rs1)))),
1458+
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1)))),
14571459
(vti.Mask true_mask),
14581460
VLOpFrag)),
14591461
(!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)

llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -652,13 +652,8 @@ define <vscale x 1 x i16> @vnsrl_wx_i64_nxv1i16(<vscale x 1 x i32> %va, i64 %b)
652652
;
653653
; RV64-LABEL: vnsrl_wx_i64_nxv1i16:
654654
; RV64: # %bb.0:
655-
; RV64-NEXT: vsetvli a1, zero, e64, m1, ta, ma
656-
; RV64-NEXT: vmv.v.x v9, a0
657-
; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
658-
; RV64-NEXT: vnsrl.wi v9, v9, 0
659-
; RV64-NEXT: vsrl.vv v8, v8, v9
660-
; RV64-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
661-
; RV64-NEXT: vnsrl.wi v8, v8, 0
655+
; RV64-NEXT: vsetvli a1, zero, e16, mf4, ta, ma
656+
; RV64-NEXT: vnsrl.wx v8, v8, a0
662657
; RV64-NEXT: ret
663658
%head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
664659
%splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
@@ -689,15 +684,8 @@ define <vscale x 1 x i8> @vnsrl_wx_i64_nxv1i8(<vscale x 1 x i16> %va, i64 %b) {
689684
;
690685
; RV64-LABEL: vnsrl_wx_i64_nxv1i8:
691686
; RV64: # %bb.0:
692-
; RV64-NEXT: vsetvli a1, zero, e64, m1, ta, ma
693-
; RV64-NEXT: vmv.v.x v9, a0
694-
; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
695-
; RV64-NEXT: vnsrl.wi v9, v9, 0
696-
; RV64-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
697-
; RV64-NEXT: vnsrl.wi v9, v9, 0
698-
; RV64-NEXT: vsrl.vv v8, v8, v9
699-
; RV64-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
700-
; RV64-NEXT: vnsrl.wi v8, v8, 0
687+
; RV64-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
688+
; RV64-NEXT: vnsrl.wx v8, v8, a0
701689
; RV64-NEXT: ret
702690
%head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
703691
%splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer

0 commit comments

Comments
 (0)