[RISCV][VLOPT] Allow users that are passthrus if tail elements aren't demanded #124066

Merged · 2 commits · Jan 30, 2025
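Summary of the change: the VL optimizer previously refused to shrink an instruction's VL whenever one of its users read the result through a tied (passthru) operand, since a tail-undisturbed user reads the passthru's lanes past its VL. After this patch it aborts only when those tail lanes are actually demanded, i.e. when the user's demanded VL cannot be proven to be at most the user's own VL operand. A minimal sketch of the rule on plain integers — passthruBlocksReduction is a hypothetical helper for illustration, not code from the patch, which compares MachineOperand VLs via RISCV::isVLKnownLE:

#include <optional>

// Hedged sketch (hypothetical helper, not from the patch): the new rule with
// plain unsigned VLs standing in for the MachineOperand VLs the pass uses.
bool passthruBlocksReduction(std::optional<unsigned> UserDemandedVL,
                             unsigned UserVL) {
  // Unknown demand has to be treated as "all lanes demanded".
  if (!UserDemandedVL)
    return true;
  // The tail lanes come from the passthru, so they are read exactly when
  // something demands more than UserVL lanes of the user's result.
  return *UserDemandedVL > UserVL;
}

Against the MIR tests below: in passthru_not_demanded only one lane of %y (VL=1) is demanded, so passthruBlocksReduction(1, 1) is false and %x can shrink to VL=1; in passthru_demanded, %z demands two lanes of %y, so the reduction of %x stays blocked.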
31 changes: 19 additions & 12 deletions llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
@@ -1188,6 +1188,25 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
    return std::nullopt;
  }

+  unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
+  const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
+  // Looking for an immediate or a register VL that isn't X0.
+  assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
+         "Did not expect X0 VL");
+
+  // If the user is a passthru it will read the elements past VL, so
+  // abort if any of the elements past VL are demanded.
+  if (UserOp.isTied()) {
+    assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
+           RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
+    auto DemandedVL = DemandedVLs[&UserMI];
+    if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
+      LLVM_DEBUG(dbgs() << "  Abort because user is passthru in "
+                           "instruction with demanded tail\n");
+      return std::nullopt;
+    }
+  }
+
  // Instructions like reductions may use a vector register as a scalar
  // register. In this case, we should treat it as only reading the first lane.
  if (isVectorOpUsedAsScalarOp(UserOp)) {
@@ -1200,12 +1219,6 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
    return MachineOperand::CreateImm(1);
  }

-  unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
-  const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
-  // Looking for an immediate or a register VL that isn't X0.
-  assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
-         "Did not expect X0 VL");
-
  // If we know the demanded VL of UserMI, then we can reduce the VL it
  // requires.
  if (auto DemandedVL = DemandedVLs[&UserMI]) {
@@ -1227,12 +1240,6 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
      return std::nullopt;
    }

-    // If used as a passthru, elements past VL will be read.
-    if (UserOp.isTied()) {
-      LLVM_DEBUG(dbgs() << "    Abort because user used as tied operand\n");
-      return std::nullopt;
-    }
-
    auto VLOp = getMinimumVLForUser(UserOp);
    if (!VLOp)
      return std::nullopt;
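The guard added above leans on RISCV::isVLKnownLE to prove demanded-VL <= VL. As a rough illustration of its semantics — an approximation under the assumption that VLMAX is encoded as the immediate -1, not a copy of the in-tree helper:

#include "llvm/CodeGen/MachineOperand.h"

// Approximate semantics of RISCV::isVLKnownLE: return true only when
// LHS <= RHS is provable from the two VL operands alone.
static bool isVLKnownLESketch(const llvm::MachineOperand &LHS,
                              const llvm::MachineOperand &RHS) {
  if (LHS.isReg() && RHS.isReg() && LHS.getReg() == RHS.getReg())
    return true;  // The same register VL is trivially equal to itself.
  if (RHS.isImm() && RHS.getImm() == -1)
    return true;  // Any VL is <= VLMAX.
  if (!LHS.isImm() || !RHS.isImm())
    return false; // Two unrelated registers: nothing is provable.
  if (LHS.getImm() == -1)
    return false; // VLMAX is only known <= VLMAX, handled above.
  return LHS.getImm() <= RHS.getImm(); // Two immediates: compare directly.
}

For instance, in the passthru_demanded MIR test below the check amounts to isVLKnownLE(2, 1), which fails, so getMinimumVLForUser returns std::nullopt and %x keeps its VLMAX (-1) operand.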
23 changes: 23 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vl-opt.ll
@@ -194,3 +194,26 @@ define <vscale x 4 x i32> @dont_optimize_tied_def(<vscale x 4 x i32> %a, <vscale
  ret <vscale x 4 x i32> %2
}

+define void @optimize_ternary_use(<vscale x 4 x i16> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c, ptr %p, iXLen %vl) {
+; NOVLOPT-LABEL: optimize_ternary_use:
+; NOVLOPT: # %bb.0:
+; NOVLOPT-NEXT: vsetvli a2, zero, e32, m2, ta, ma
+; NOVLOPT-NEXT: vzext.vf2 v14, v8
+; NOVLOPT-NEXT: vsetvli zero, a1, e32, m2, ta, ma
+; NOVLOPT-NEXT: vmadd.vv v14, v10, v12
+; NOVLOPT-NEXT: vse32.v v14, (a0)
+; NOVLOPT-NEXT: ret
+;
+; VLOPT-LABEL: optimize_ternary_use:
+; VLOPT: # %bb.0:
+; VLOPT-NEXT: vsetvli zero, a1, e32, m2, ta, ma
+; VLOPT-NEXT: vzext.vf2 v14, v8
+; VLOPT-NEXT: vmadd.vv v14, v10, v12
+; VLOPT-NEXT: vse32.v v14, (a0)
+; VLOPT-NEXT: ret
+  %1 = zext <vscale x 4 x i16> %a to <vscale x 4 x i32>
+  %2 = mul <vscale x 4 x i32> %b, %1
+  %3 = add <vscale x 4 x i32> %2, %c
+  call void @llvm.riscv.vse(<vscale x 4 x i32> %3, ptr %p, iXLen %vl)
+  ret void
+}
61 changes: 61 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/vl-opt.mir
@@ -209,3 +209,64 @@ body: |
  bb.1:
    %y:vr = PseudoVADD_VV_M1 $noreg, %x, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
    PseudoRET
+...
+---
+# Can reduce %x even though %y uses it as a passthru, because %y's inactive elements aren't demanded
+name: passthru_not_demanded
+body: |
+  bb.0:
+    ; CHECK-LABEL: name: passthru_not_demanded
+    ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %y, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %z:vr = PseudoVADD_VV_M1 $noreg, %y, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+...
+---
+# Can't reduce %x because %y uses it as a passthru, and %y's inactive elements are demanded by %z
+name: passthru_demanded
+body: |
+  bb.0:
+    ; CHECK-LABEL: name: passthru_demanded
+    ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 $noreg, %y, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
+    %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %z:vr = PseudoVADD_VV_M1 $noreg, %y, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
+...
+---
+# Can reduce %x even though %y uses it as a passthru, because %y's inactive elements aren't demanded
+name: passthru_not_demanded_passthru_chain
+body: |
+  bb.0:
+    ; CHECK-LABEL: name: passthru_not_demanded_passthru_chain
+    ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 %y, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %a:vr = PseudoVADD_VV_M1 %z, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %b:vr = PseudoVADD_VV_M1 $noreg, %a, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %z:vr = PseudoVADD_VV_M1 %y, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %a:vr = PseudoVADD_VV_M1 %z, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %b:vr = PseudoVADD_VV_M1 $noreg, %a, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+...
+---
+# Can't reduce %x because %y uses it as a passthru, and %y's inactive elements are ultimately demanded in %b
+name: passthru_demanded_passthru_chain
+body: |
+  bb.0:
+    ; CHECK-LABEL: name: passthru_demanded_passthru_chain
+    ; CHECK: %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %z:vr = PseudoVADD_VV_M1 %y, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %a:vr = PseudoVADD_VV_M1 %z, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    ; CHECK-NEXT: %b:vr = PseudoVADD_VV_M1 $noreg, %a, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
+    %x:vr = PseudoVADD_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */, 0 /* tu, mu */
+    %y:vr = PseudoVADD_VV_M1 %x, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %z:vr = PseudoVADD_VV_M1 %y, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %a:vr = PseudoVADD_VV_M1 %z, $noreg, $noreg, 1, 3 /* e8 */, 0 /* tu, mu */
+    %b:vr = PseudoVADD_VV_M1 $noreg, %a, $noreg, 2, 3 /* e8 */, 0 /* tu, mu */
+...
36 changes: 12 additions & 24 deletions llvm/test/CodeGen/RISCV/rvv/vmadd-vp.ll
@@ -1638,9 +1638,8 @@ define <vscale x 1 x i64> @vmadd_vx_nxv1i64(<vscale x 1 x i64> %a, i64 %b, <vsca
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m1, ta, ma
+; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vmadd.vv v10, v8, v9
; RV32-NEXT: vsetvli zero, zero, e64, m1, tu, ma
; RV32-NEXT: vmerge.vvm v8, v8, v10, v0
@@ -1669,9 +1668,8 @@ define <vscale x 1 x i64> @vmadd_vx_nxv1i64_unmasked(<vscale x 1 x i64> %a, i64
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m1, ta, ma
+; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vmadd.vv v10, v8, v9
; RV32-NEXT: vsetvli zero, zero, e64, m1, tu, ma
; RV32-NEXT: vmv.v.v v8, v10
@@ -1713,9 +1711,8 @@ define <vscale x 1 x i64> @vmadd_vx_nxv1i64_ta(<vscale x 1 x i64> %a, i64 %b, <v
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m1, ta, ma
-; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m1, ta, ma
+; RV32-NEXT: vlse64.v v10, (a0), zero
; RV32-NEXT: vmadd.vv v10, v8, v9
; RV32-NEXT: vmerge.vvm v8, v8, v10, v0
; RV32-NEXT: addi sp, sp, 16
@@ -1776,9 +1773,8 @@ define <vscale x 2 x i64> @vmadd_vx_nxv2i64(<vscale x 2 x i64> %a, i64 %b, <vsca
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m2, ta, ma
-; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m2, ta, ma
+; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vmadd.vv v12, v8, v10
; RV32-NEXT: vsetvli zero, zero, e64, m2, tu, ma
; RV32-NEXT: vmerge.vvm v8, v8, v12, v0
@@ -1807,9 +1803,8 @@ define <vscale x 2 x i64> @vmadd_vx_nxv2i64_unmasked(<vscale x 2 x i64> %a, i64
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m2, ta, ma
-; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m2, ta, ma
+; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vmadd.vv v12, v8, v10
; RV32-NEXT: vsetvli zero, zero, e64, m2, tu, ma
; RV32-NEXT: vmv.v.v v8, v12
@@ -1851,9 +1846,8 @@ define <vscale x 2 x i64> @vmadd_vx_nxv2i64_ta(<vscale x 2 x i64> %a, i64 %b, <v
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m2, ta, ma
-; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m2, ta, ma
+; RV32-NEXT: vlse64.v v12, (a0), zero
; RV32-NEXT: vmadd.vv v12, v8, v10
; RV32-NEXT: vmerge.vvm v8, v8, v12, v0
; RV32-NEXT: addi sp, sp, 16
@@ -1914,9 +1908,8 @@ define <vscale x 4 x i64> @vmadd_vx_nxv4i64(<vscale x 4 x i64> %a, i64 %b, <vsca
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m4, ta, ma
-; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m4, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vmadd.vv v16, v8, v12
; RV32-NEXT: vsetvli zero, zero, e64, m4, tu, ma
; RV32-NEXT: vmerge.vvm v8, v8, v16, v0
@@ -1945,9 +1938,8 @@ define <vscale x 4 x i64> @vmadd_vx_nxv4i64_unmasked(<vscale x 4 x i64> %a, i64
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m4, ta, ma
-; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m4, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vmadd.vv v16, v8, v12
; RV32-NEXT: vsetvli zero, zero, e64, m4, tu, ma
; RV32-NEXT: vmv.v.v v8, v16
@@ -1989,9 +1981,8 @@ define <vscale x 4 x i64> @vmadd_vx_nxv4i64_ta(<vscale x 4 x i64> %a, i64 %b, <v
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m4, ta, ma
-; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m4, ta, ma
+; RV32-NEXT: vlse64.v v16, (a0), zero
; RV32-NEXT: vmadd.vv v16, v8, v12
; RV32-NEXT: vmerge.vvm v8, v8, v16, v0
; RV32-NEXT: addi sp, sp, 16
@@ -2054,9 +2045,8 @@ define <vscale x 8 x i64> @vmadd_vx_nxv8i64(<vscale x 8 x i64> %a, i64 %b, <vsca
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
-; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m8, ta, ma
+; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vmadd.vv v24, v8, v16
; RV32-NEXT: vsetvli zero, zero, e64, m8, tu, ma
; RV32-NEXT: vmerge.vvm v8, v8, v24, v0
@@ -2085,9 +2075,8 @@ define <vscale x 8 x i64> @vmadd_vx_nxv8i64_unmasked(<vscale x 8 x i64> %a, i64
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
-; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m8, ta, ma
+; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vmadd.vv v24, v8, v16
; RV32-NEXT: vsetvli zero, zero, e64, m8, tu, ma
; RV32-NEXT: vmv.v.v v8, v24
@@ -2130,9 +2119,8 @@ define <vscale x 8 x i64> @vmadd_vx_nxv8i64_ta(<vscale x 8 x i64> %a, i64 %b, <v
; RV32-NEXT: sw a0, 8(sp)
; RV32-NEXT: sw a1, 12(sp)
; RV32-NEXT: addi a0, sp, 8
-; RV32-NEXT: vsetvli a1, zero, e64, m8, ta, ma
-; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vsetvli zero, a2, e64, m8, ta, ma
+; RV32-NEXT: vlse64.v v24, (a0), zero
; RV32-NEXT: vmadd.vv v24, v8, v16
; RV32-NEXT: vmerge.vvm v8, v8, v24, v0
; RV32-NEXT: addi sp, sp, 16