[RISCV] Add patterns for fixed vector vwsll #87316

Merged · 4 commits · Apr 4, 2024
30 changes: 15 additions & 15 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3287,24 +3287,24 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
 }
 
 bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
-  // Truncates are custom lowered during legalization.
-  auto IsTrunc = [this](SDValue N) {
-    if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+  auto IsExtOrTrunc = [](SDValue N) {
+    switch (N->getOpcode()) {
+    case ISD::SIGN_EXTEND:
+    case ISD::ZERO_EXTEND:
+    // There's no passthru on these _VL nodes so any VL/mask is ok, since any
+    // inactive elements will be undef.
+    case RISCVISD::TRUNCATE_VECTOR_VL:
+    case RISCVISD::VSEXT_VL:
+    case RISCVISD::VZEXT_VL:
+      return true;
+    default:
       return false;
-    SDValue VL;
-    selectVLOp(N->getOperand(2), VL);
-    // Any vmset_vl is ok, since any bits past VL are undefined and we can
-    // assume they are set.
-    return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
-           isa<ConstantSDNode>(VL) &&
-           cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
+    }
   };
 
-  // We can have multiple nested truncates, so unravel them all if needed.
-  while (N->getOpcode() == ISD::SIGN_EXTEND ||
-         N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
-    if (!N.hasOneUse() ||
-        N.getValueType().getSizeInBits().getKnownMinValue() < 8)
+  // We can have multiple nested nodes, so unravel them all if needed.
+  while (IsExtOrTrunc(N)) {
+    if (!N.hasOneUse() || N.getScalarValueSizeInBits() < 8)
       return false;
     N = N->getOperand(0);
   }
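The selectLow8BitsVSplat change matters for the .vx patterns added below: on fixed vectors the splatted scalar shift amount is usually wrapped in extend or truncate nodes (including the custom _VL variants) before it reaches the shift, and the selector now peels any mix of them as long as every intermediate element type is at least 8 bits wide. A hedged LLVM IR sketch of the shape this enables (illustrative only, not copied from the PR's tests; with +v and +zvbb this kind of input should be able to select vwsll.vx):

define <8 x i16> @vwsll_vx_sketch(<8 x i8> %a, i8 %b) {
  ; Splat the i8 shift amount, then widen it alongside the data operand.
  %head = insertelement <8 x i8> poison, i8 %b, i32 0
  %splat = shufflevector <8 x i8> %head, <8 x i8> poison, <8 x i32> zeroinitializer
  %x = zext <8 x i8> %a to <8 x i16>
  %y = zext <8 x i8> %splat to <8 x i16>   ; extend of a splat: peeled by selectLow8BitsVSplat
  %z = shl <8 x i16> %x, %y
  ret <8 x i16> %z
}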
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -387,6 +387,9 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
                                                SDTCisVT<3, XLenVT>]>;
 def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
 def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
+def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
+                            [(riscv_sext_vl node:$A, node:$B, node:$C),
+                             (riscv_zext_vl node:$A, node:$B, node:$C)]>;
 
 def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
                                    SDTypeProfile<1, 3, [SDTCisVec<0>,
@@ -535,6 +538,11 @@ def riscv_zext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
   return N->hasOneUse();
 }]>;
 
+def riscv_ext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
+                                  (riscv_ext_vl node:$A, node:$B, node:$C), [{
+  return N->hasOneUse();
+}]>;
+
 def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
                                        (riscv_fpextend_vl node:$A, node:$B, node:$C), [{
   return N->hasOneUse();
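riscv_ext_vl and riscv_ext_vl_oneuse exist because vwsll only consumes the low log2(2*SEW) bits of its shift-amount operand, so it makes no difference whether that operand was sign- or zero-extended from the narrow element type; the PatFrags let a single pattern accept either extension. A hedged example of the sign-extended variant (illustrative IR under assumed +v,+zvbb; if the extended amount is ever out of range the shl is poison anyway, so selecting vwsll.vv stays sound):

define <4 x i32> @vwsll_vv_sext_amount(<4 x i16> %a, <4 x i16> %b) {
  %x = zext <4 x i16> %a to <4 x i32>
  %y = sext <4 x i16> %b to <4 x i32>   ; either extension of the shift amount is acceptable
  %z = shl <4 x i32> %x, %y
  ret <4 x i32> %z
}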
35 changes: 35 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td
@@ -642,6 +642,19 @@ foreach vtiToWti = AllWidenableIntVectors in {
               wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
+  def : Pat<(riscv_shl_vl
+              (wti.Vector (riscv_zext_vl_oneuse
+                            (vti.Vector vti.RegClass:$rs2),
+                            (vti.Mask V0), VLOpFrag)),
+              (wti.Vector (riscv_ext_vl_oneuse
+                            (vti.Vector vti.RegClass:$rs1),
+                            (vti.Mask V0), VLOpFrag)),
+              (wti.Vector wti.RegClass:$merge),
+              (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK")
+              wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
+              (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
   def : Pat<(riscv_shl_vl
               (wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
               (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
@@ -651,6 +664,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
               wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
+  def : Pat<(riscv_shl_vl
+              (wti.Vector (riscv_zext_vl_oneuse
+                            (vti.Vector vti.RegClass:$rs2),
+                            (vti.Mask V0), VLOpFrag)),
+              (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
+              (wti.Vector wti.RegClass:$merge),
+              (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK")
+              wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
+              (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
   def : Pat<(riscv_shl_vl
               (wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
               (wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
@@ -660,6 +684,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
               wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
               (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 
+  def : Pat<(riscv_shl_vl
+              (wti.Vector (riscv_zext_vl_oneuse
+                            (vti.Vector vti.RegClass:$rs2),
+                            (vti.Mask V0), VLOpFrag)),
+              (wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
+              (wti.Vector wti.RegClass:$merge),
+              (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
+              wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
+              (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+
   def : Pat<(riscv_vwsll_vl
               (vti.Vector vti.RegClass:$rs2),
               (vti.Vector vti.RegClass:$rs1),
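The three new patterns are the masked (_MASK) pseudo variants of the existing VV/VX/VI vwsll patterns because fixed-length vectors are legalized onto scalable containers as VL-predicated nodes (riscv_shl_vl fed by riscv_zext_vl/riscv_ext_vl) rather than plain shl/zext. A hedged end-to-end sketch of the immediate form (illustrative, not copied from the PR's tests): with -mattr=+v,+zvbb the shift below should now be selectable as a single vwsll.vi instead of a separate zero-extend and shift.

define <4 x i64> @vwsll_vi_sketch(<4 x i32> %a) {
  %x = zext <4 x i32> %a to <4 x i64>
  ; Shift by a constant splat that fits uimm5, matching the VI pattern.
  %z = shl <4 x i64> %x, <i64 2, i64 2, i64 2, i64 2>
  ret <4 x i64> %z
}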