Skip to content

[RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext. #93578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16128,23 +16128,26 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}

// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
// This would be benefit for the cases where X and Y are both the same value
// type of low precision vectors. Since the truncate would be lowered into
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
// restriction, such pattern would be expanded into a series of "vsetvli"
// and "vnsrl" instructions later to reach this point.
static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
// This would be benefit for the cases where X and Y are both the same value
// type of low precision vectors. Since the truncate would be lowered into
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
// restriction, such pattern would be expanded into a series of "vsetvli"
// and "vnsrl" instructions later to reach this point.
auto IsTruncNode = [](SDValue V) {
if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
return false;
SDValue VL = V.getOperand(2);
auto *C = dyn_cast<ConstantSDNode>(VL);
// Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
(isa<RegisterSDNode>(VL) &&
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
SDValue Mask = N->getOperand(1);
SDValue VL = N->getOperand(2);

bool IsVLMAX = isAllOnesConstant(VL) ||
(isa<RegisterSDNode>(VL) &&
cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
Mask.getOperand(0) != VL)
Copy link
Contributor

@lukel97 lukel97 May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like IsVLMAXForVMSET didn't actually check the VMSET's VL in the first place. But since we can interpret any VMSET_VL as all ones as the tail is undefined would it make sense to drop the Mask's VL check here?

Copy link
Collaborator Author

@topperc topperc May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but we don't have tests for it.

return SDValue();

auto IsTruncNode = [&](SDValue V) {
return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
V.getOperand(1) == Mask && V.getOperand(2) == VL;
};

SDValue Op = N->getOperand(0);
Expand Down
14 changes: 9 additions & 5 deletions llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
Original file line number Diff line number Diff line change
Expand Up @@ -937,13 +937,17 @@ define <vscale x 8 x i32> @vsra_vi_mask_nxv8i32(<vscale x 8 x i32> %va, <vscale

; Negative test. We shouldn't look through the vp.trunc as it isn't vlmax like
; the rest of the code.
define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 %evl) {
define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext_mixed_trunc:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 7
; CHECK-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
; CHECK-NEXT: vmin.vx v9, v8, a0
; CHECK-NEXT: vsra.vv v8, v8, v9
; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT: vsext.vf4 v9, v8
; CHECK-NEXT: vzext.vf4 v10, v8
; CHECK-NEXT: vsra.vv v8, v9, v10
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
; CHECK-NEXT: vnsrl.wi v8, v8, 0
; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
; CHECK-NEXT: vnsrl.wi v8, v8, 0, v0.t
; CHECK-NEXT: ret
%sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
%zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
Expand Down
Loading