
Commit 3f83a69

[RISCV] Allow folding vmerge into masked ops when mask is the same (#97989)
We currently only fold a vmerge into a masked true operand if the vmerge has an all-ones mask, since we end up keeping the mask from the true operand. But if the two masks are the same then we can still fold, because vmerge and true have the same passthru: if an element was masked off in the original vmerge, it will also be masked off in the resulting true, and will have the same passthru value.

The motivation for this is to lower masked VP loads and stores with passthrus to masked RVV instructions. Normally you can express a masked RVV instruction with a mask-undisturbed passthru via a combination of a VP op with an all-ones mask and a vp.merge. But for loads and stores you need the same mask on the VP op as well as the vp.merge.
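To make the motivation concrete, here is a minimal IR sketch of the pattern this enables (it mirrors the vpload_nxv1i8_passthru test added to vpload.ll below; the function name is illustrative). The vp.load and the vp.merge share the mask %m, so the merge can now fold into the load, which lowers to a single masked vle8.v under a tail-undisturbed, mask-undisturbed (tu, mu) policy:

define <vscale x 1 x i8> @load_with_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
  ; Same mask %m on both the VP load and the vp.merge.
  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
  ; Lanes where %m is false (or at or beyond %evl) keep the value from %passthru.
  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
  ret <vscale x 1 x i8> %merge
}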
Parent: d0f3943 · Commit: 3f83a69

File tree: 4 files changed, +55 -21 lines

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
llvm/test/CodeGen/RISCV/rvv/vpload.ll


llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 23 additions & 8 deletions

@@ -3523,24 +3523,26 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
   return false;
 }
 
-static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+// After ISel, a vector pseudo's mask will be copied to V0 via a CopyToReg
+// that's glued to the pseudo. This tries to look up the value that was copied
+// to V0.
+static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
   // Check that we're using V0 as a mask register.
   if (!isa<RegisterSDNode>(MaskOp) ||
       cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
   // The glued user defines V0.
   const auto *Glued = GlueOp.getNode();
 
   if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
-    return false;
+    return SDValue();
 
   // Check that we're defining V0 as a mask register.
   if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
       cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
-    return false;
+    return SDValue();
 
-  // Check the instruction defining V0; it needs to be a VMSET pseudo.
   SDValue MaskSetter = Glued->getOperand(2);
 
   // Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
@@ -3549,6 +3551,15 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
       MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
     MaskSetter = MaskSetter->getOperand(0);
 
+  return MaskSetter;
+}
+
+static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
+  if (!MaskSetter)
+    return false;
+
   const auto IsVMSet = [](unsigned Opc) {
     return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
            Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
@@ -3755,12 +3766,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
     return false;
   }
 
-  // If True is masked then the vmerge must have an all 1s mask, since we're
-  // going to keep the mask from True.
+  // If True is masked then the vmerge must have either the same mask or an all
+  // 1s mask, since we're going to keep the mask from True.
   if (IsMasked && Mask) {
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (!usesAllOnesMask(Mask, Glue))
+    SDValue TrueMask =
+        getMaskSetter(True->getOperand(Info->MaskOpIdx),
+                      True->getOperand(True->getNumOperands() - 1));
+    assert(TrueMask);
+    if (!usesAllOnesMask(Mask, Glue) && getMaskSetter(Mask, Glue) != TrueMask)
       return false;
   }
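At the intrinsic level, here is a minimal sketch of the two foldable shapes (modeled on the tests in rvv-peephole-vmerge-masked-vops.ll; value names are illustrative). Previously only the first shape was folded; with this change the second folds as well, since in both cases masked-off lanes read the same %passthru:

; Shape 1 (already foldable): True is masked by %m, but the vmerge uses an
; all-ones mask, so keeping True's mask %m is trivially safe.
%a0 = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
%b0 = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a0, <vscale x 2 x i1> splat (i1 -1), i64 %vl)

; Shape 2 (newly foldable): the vmerge reuses %m. Any lane masked off in the
; vmerge is also masked off in True, and both fall back to %passthru, so
; dropping the vmerge does not change the result.
%a1 = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
%b1 = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a1, <vscale x 2 x i1> %m, i64 %vl)

Either way the pair lowers to a single vadd.vv under a tu, mu vtype, as the new test below checks.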

llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll

Lines changed: 11 additions & 0 deletions

@@ -240,3 +240,14 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
   %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %1)
   ret <vscale x 2 x i32> %b
 }
+
+define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vadd_same_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}

llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll

Lines changed: 2 additions & 5 deletions

@@ -85,11 +85,8 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscal
 define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; ZVFH-LABEL: vfmacc_vv_nxv1f32_masked__tu:
 ; ZVFH:       # %bb.0:
-; ZVFH-NEXT:    vmv1r.v v11, v10
-; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
-; ZVFH-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
-; ZVFH-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
-; ZVFH-NEXT:    vmerge.vvm v10, v10, v11, v0
+; ZVFH-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; ZVFH-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
 ; ZVFH-NEXT:    vmv1r.v v8, v10
 ; ZVFH-NEXT:    ret
 ;
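The IR body of this test is not part of the diff; a plausible sketch of what it exercises (assuming the vp.fma plus vp.merge formulation this file uses; value names are illustrative) is:

; The widening FMA and the merge share %m, so the combine can now rewrite the
; pair as a single tail-undisturbed, mask-undisturbed vfwmacc.vv.
%ae = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %m, i32 %evl)
%be = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %m, i32 %evl)
%v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %ae, <vscale x 1 x float> %be, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 %evl)
%u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
ret <vscale x 1 x float> %u

Before this change the vmerge could not fold because its mask was not all-ones, which is why the old output kept a separate vmerge.vvm.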

llvm/test/CodeGen/RISCV/rvv/vpload.ll

Lines changed: 19 additions & 8 deletions

@@ -26,6 +26,17 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
   ret <vscale x 1 x i8> %load
 }
 
+define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
+; CHECK-LABEL: vpload_nxv1i8_passthru:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf8, tu, mu
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
+  %merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
+  ret <vscale x 1 x i8> %merge
+}
+
 declare <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr, <vscale x 2 x i1>, i32)
 
 define <vscale x 2 x i8> @vpload_nxv2i8(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
@@ -450,10 +461,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
 ; CHECK-NEXT:    add a4, a0, a4
 ; CHECK-NEXT:    vsetvli zero, a3, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v16, (a4), v0.t
-; CHECK-NEXT:    bltu a1, a2, .LBB37_2
+; CHECK-NEXT:    bltu a1, a2, .LBB38_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a1, a2
-; CHECK-NEXT:  .LBB37_2:
+; CHECK-NEXT:  .LBB38_2:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a1, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t
@@ -480,10 +491,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
 ; CHECK-NEXT:    slli a5, a3, 1
 ; CHECK-NEXT:    vmv1r.v v8, v0
 ; CHECK-NEXT:    mv a4, a2
-; CHECK-NEXT:    bltu a2, a5, .LBB38_2
+; CHECK-NEXT:    bltu a2, a5, .LBB39_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a4, a5
-; CHECK-NEXT:  .LBB38_2:
+; CHECK-NEXT:  .LBB39_2:
 ; CHECK-NEXT:    sub a6, a4, a3
 ; CHECK-NEXT:    sltu a7, a4, a6
 ; CHECK-NEXT:    addi a7, a7, -1
@@ -499,21 +510,21 @@
 ; CHECK-NEXT:    sltu a2, a2, a5
 ; CHECK-NEXT:    addi a2, a2, -1
 ; CHECK-NEXT:    and a2, a2, a5
-; CHECK-NEXT:    bltu a2, a3, .LBB38_4
+; CHECK-NEXT:    bltu a2, a3, .LBB39_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    mv a2, a3
-; CHECK-NEXT:  .LBB38_4:
+; CHECK-NEXT:  .LBB39_4:
 ; CHECK-NEXT:    slli a5, a3, 4
 ; CHECK-NEXT:    srli a6, a3, 2
 ; CHECK-NEXT:    vsetvli a7, zero, e8, mf2, ta, ma
 ; CHECK-NEXT:    vslidedown.vx v0, v8, a6
 ; CHECK-NEXT:    add a5, a0, a5
 ; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v24, (a5), v0.t
-; CHECK-NEXT:    bltu a4, a3, .LBB38_6
+; CHECK-NEXT:    bltu a4, a3, .LBB39_6
 ; CHECK-NEXT:  # %bb.5:
 ; CHECK-NEXT:    mv a4, a3
-; CHECK-NEXT:  .LBB38_6:
+; CHECK-NEXT:  .LBB39_6:
 ; CHECK-NEXT:    vmv1r.v v0, v8
 ; CHECK-NEXT:    vsetvli zero, a4, e64, m8, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0), v0.t
