Skip to content

[RISCV] Add patterns for fixed vector vwsll #87316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3287,22 +3287,25 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
}

bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
// Truncates are custom lowered during legalization.
auto IsTrunc = [this](SDValue N) {
if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
auto IsVLNode = [this](SDValue N) {
switch (N->getOpcode()) {
case RISCVISD::TRUNCATE_VECTOR_VL:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
break;
default:
return false;
}
SDValue VL;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to select VL is we aren't going to use VL?

selectVLOp(N->getOperand(2), VL);
// Any vmset_vl is ok, since any bits past VL are undefined and we can
// assume they are set.
return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
isa<ConstantSDNode>(VL) &&
cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
// There's no passthru so any mask is ok, since any inactive elements will
// be undef.
return true;
Copy link
Collaborator

@topperc topperc Apr 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This return is unreachable

};

// We can have multiple nested truncates, so unravel them all if needed.
// We can have multiple nested nodes, so unravel them all if needed.
while (N->getOpcode() == ISD::SIGN_EXTEND ||
N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
N->getOpcode() == ISD::ZERO_EXTEND || IsVLNode(N)) {
if (!N.hasOneUse() ||
N.getValueType().getSizeInBits().getKnownMinValue() < 8)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to this patch, but what is this N.getValueType().getSizeInBits().getKnownMinValue() doing? Shouldn't it be checking element size instead of the size of the vector?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that should definitely be getScalarSizeInBits()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've fixed it in this PR 785af38, let me know if I should split it off

return false;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisVT<3, XLenVT>]>;
def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
[(riscv_sext_vl node:$A, node:$B, node:$C),
(riscv_zext_vl node:$A, node:$B, node:$C)]>;

def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>,
Expand Down Expand Up @@ -535,6 +538,11 @@ def riscv_zext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
return N->hasOneUse();
}]>;

def riscv_ext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
(riscv_ext_vl node:$A, node:$B, node:$C), [{
return N->hasOneUse();
}]>;

def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
(riscv_fpextend_vl node:$A, node:$B, node:$C), [{
return N->hasOneUse();
Expand Down
35 changes: 35 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,19 @@ foreach vtiToWti = AllWidenableIntVectors in {
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_shl_vl
(wti.Vector (riscv_zext_vl_oneuse
(vti.Vector vti.RegClass:$rs2),
(vti.Mask V0), VLOpFrag)),
(wti.Vector (riscv_ext_vl_oneuse
(vti.Vector vti.RegClass:$rs1),
(vti.Mask V0), VLOpFrag)),
(wti.Vector wti.RegClass:$merge),
(vti.Mask V0), VLOpFrag),
(!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK")
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_shl_vl
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
Expand All @@ -651,6 +664,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_shl_vl
(wti.Vector (riscv_zext_vl_oneuse
(vti.Vector vti.RegClass:$rs2),
(vti.Mask V0), VLOpFrag)),
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
(wti.Vector wti.RegClass:$merge),
(vti.Mask V0), VLOpFrag),
(!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK")
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_shl_vl
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
Expand All @@ -660,6 +684,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_shl_vl
(wti.Vector (riscv_zext_vl_oneuse
(vti.Vector vti.RegClass:$rs2),
(vti.Mask V0), VLOpFrag)),
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
(wti.Vector wti.RegClass:$merge),
(vti.Mask V0), VLOpFrag),
(!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;

def : Pat<(riscv_vwsll_vl
(vti.Vector vti.RegClass:$rs2),
(vti.Vector vti.RegClass:$rs1),
Expand Down
Loading