Skip to content

[GlobalISel] Add boolean predicated legalization action methods. #111287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,11 +599,22 @@ class LegalizeRuleSet {
LegalizeRuleSet &legalFor(std::initializer_list<LLT> Types) {
// Unconditional overload: forwards Types to actionFor with
// LegalizeAction::Legal and returns whatever actionFor returns.
return actionFor(LegalizeAction::Legal, Types);
}
/// Predicated form of legalFor(Types). When \p Pred is true this forwards
/// \p Types to actionFor with LegalizeAction::Legal; when \p Pred is false
/// nothing is called and *this is returned, so the call can stay inside a
/// builder chain regardless of the condition.
LegalizeRuleSet &legalFor(bool Pred, std::initializer_list<LLT> Types) {
  return Pred ? actionFor(LegalizeAction::Legal, Types) : *this;
}
/// The instruction is legal when type indexes 0 and 1 form any type pair in
/// the given list.
LegalizeRuleSet &legalFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
return actionFor(LegalizeAction::Legal, Types);
}
/// Predicated form of the type-pair legalFor overload. When \p Pred is true
/// this forwards the pairs in \p Types to actionFor with
/// LegalizeAction::Legal; when \p Pred is false nothing is called and *this
/// is returned.
LegalizeRuleSet &legalFor(bool Pred,
                          std::initializer_list<std::pair<LLT, LLT>> Types) {
  return Pred ? actionFor(LegalizeAction::Legal, Types) : *this;
}
/// The instruction is legal when type index 0 is any type in the given list
/// and imm index 0 is anything.
LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
Expand Down Expand Up @@ -846,12 +857,23 @@ class LegalizeRuleSet {
LegalizeRuleSet &customFor(std::initializer_list<LLT> Types) {
// Unconditional overload: forwards Types to actionFor with
// LegalizeAction::Custom and returns whatever actionFor returns.
return actionFor(LegalizeAction::Custom, Types);
}
/// Predicated form of customFor(Types). When \p Pred is true this forwards
/// \p Types to actionFor with LegalizeAction::Custom; when \p Pred is false
/// nothing is called and *this is returned.
LegalizeRuleSet &customFor(bool Pred, std::initializer_list<LLT> Types) {
  return Pred ? actionFor(LegalizeAction::Custom, Types) : *this;
}

/// The instruction is custom when type indexes 0 and 1 form any type pair in
/// the given list.
LegalizeRuleSet &customFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
return actionFor(LegalizeAction::Custom, Types);
}
/// Predicated form of the type-pair customFor overload. When \p Pred is true
/// this forwards the pairs in \p Types to actionFor with
/// LegalizeAction::Custom; when \p Pred is false nothing is called and *this
/// is returned.
LegalizeRuleSet &customFor(bool Pred,
                           std::initializer_list<std::pair<LLT, LLT>> Types) {
  return Pred ? actionFor(LegalizeAction::Custom, Types) : *this;
}

LegalizeRuleSet &customForCartesianProduct(std::initializer_list<LLT> Types) {
return actionForCartesianProduct(LegalizeAction::Custom, Types);
Expand Down Expand Up @@ -990,6 +1012,11 @@ class LegalizeRuleSet {
scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
changeTo(typeIdx(TypeIdx), Ty));
}
/// Predicated form of minScalar(TypeIdx, Ty). When \p Pred is true this
/// forwards to the two-argument minScalar overload; when \p Pred is false
/// nothing is called and *this is returned.
LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
  return Pred ? minScalar(TypeIdx, Ty) : *this;
}

/// Ensure the scalar is at least as wide as Ty if condition is met.
LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,
Expand Down
181 changes: 54 additions & 127 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,19 +215,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.legalFor({s64, v8s16, v16s8, v4s32})
.lower();

auto &MinMaxActions = getActionDefinitionsBuilder(
{G_SMIN, G_SMAX, G_UMIN, G_UMAX});
if (HasCSSC)
MinMaxActions
.legalFor({s32, s64, v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
// Making clamping conditional on CSSC extension as without legal types we
// lower to CMP which can fold one of the two sxtb's we'd otherwise need
// if we detect a type smaller than 32-bit.
.minScalar(0, s32);
else
MinMaxActions
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32});
MinMaxActions
getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
.legalFor(HasCSSC, {s32, s64})
Comment on lines +219 to +220
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this could fuse into one rule with all of the types in the list, but I doubt there's a way to make that pretty

.minScalar(HasCSSC, 0, s32)
.clampNumElements(0, v8s8, v16s8)
.clampNumElements(0, v4s16, v8s16)
.clampNumElements(0, v2s32, v4s32)
Expand All @@ -247,11 +238,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FSQRT, G_FMAXNUM, G_FMINNUM,
G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR, G_FRINT, G_FNEARBYINT,
G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[0];
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
})
.legalFor({s32, s64, v2s32, v4s32, v2s64})
.legalFor(HasFP16, {s16, v4s16, v8s16})
.libcallFor({s128})
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.minScalarOrElt(0, MinFPScalar)
Expand All @@ -261,11 +249,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.moreElementsToNextPow2(0);

getActionDefinitionsBuilder({G_FABS, G_FNEG})
.legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[0];
return (Ty == v8s16 || Ty == v4s16) && HasFP16;
})
.legalFor({s32, s64, v2s32, v4s32, v2s64})
.legalFor(HasFP16, {s16, v4s16, v8s16})
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.lowerIf(scalarOrEltWiderThan(0, 64))
.clampNumElements(0, v4s16, v8s16)
Expand Down Expand Up @@ -350,31 +335,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
};

auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
auto &StoreActions = getActionDefinitionsBuilder(G_STORE);

if (ST.hasSVE()) {
LoadActions.legalForTypesWithMemDesc({
// 128 bit base sizes
{nxv16s8, p0, nxv16s8, 8},
{nxv8s16, p0, nxv8s16, 8},
{nxv4s32, p0, nxv4s32, 8},
{nxv2s64, p0, nxv2s64, 8},
});

// TODO: Add nxv2p0. Consider bitcastIf.
// See #92130
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
StoreActions.legalForTypesWithMemDesc({
// 128 bit base sizes
{nxv16s8, p0, nxv16s8, 8},
{nxv8s16, p0, nxv8s16, 8},
{nxv4s32, p0, nxv4s32, 8},
{nxv2s64, p0, nxv2s64, 8},
});
}

LoadActions
getActionDefinitionsBuilder(G_LOAD)
.customIf([=](const LegalityQuery &Query) {
return HasRCPC3 && Query.Types[0] == s128 &&
Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
Expand All @@ -399,6 +360,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
// These extends are also legal
.legalForTypesWithMemDesc(
{{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
.legalForTypesWithMemDesc({
// SVE vscale x 128 bit base sizes
{nxv16s8, p0, nxv16s8, 8},
{nxv8s16, p0, nxv8s16, 8},
{nxv4s32, p0, nxv4s32, 8},
{nxv2s64, p0, nxv2s64, 8},
})
.widenScalarToNextPow2(0, /* MinSize = */ 8)
.clampMaxNumElements(0, s8, 16)
.clampMaxNumElements(0, s16, 8)
Expand All @@ -424,7 +392,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.customIf(IsPtrVecPred)
.scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);

StoreActions
getActionDefinitionsBuilder(G_STORE)
.customIf([=](const LegalityQuery &Query) {
return HasRCPC3 && Query.Types[0] == s128 &&
Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
Expand All @@ -444,6 +412,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{p0, p0, s64, 8}, {s128, p0, s128, 8}, {v16s8, p0, s128, 8},
{v8s8, p0, s64, 8}, {v4s16, p0, s64, 8}, {v8s16, p0, s128, 8},
{v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
.legalForTypesWithMemDesc({
// SVE vscale x 128 bit base sizes
// TODO: Add nxv2p0. Consider bitcastIf.
// See #92130
// https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
{nxv16s8, p0, nxv16s8, 8},
{nxv8s16, p0, nxv8s16, 8},
{nxv4s32, p0, nxv4s32, 8},
{nxv2s64, p0, nxv2s64, 8},
})
.clampScalar(0, s8, s64)
.lowerIf([=](const LegalityQuery &Query) {
return Query.Types[0].isScalar() &&
Expand Down Expand Up @@ -530,12 +508,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.widenScalarToNextPow2(0)
.clampScalar(0, s8, s64);
getActionDefinitionsBuilder(G_FCONSTANT)
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[0];
if (HasFP16 && Ty == s16)
return true;
return Ty == s32 || Ty == s64 || Ty == s128;
})
.legalFor({s32, s64, s128})
.legalFor(HasFP16, {s16})
.clampScalar(0, MinFPScalar, s128);

// FIXME: fix moreElementsToNextPow2
Expand Down Expand Up @@ -567,16 +541,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.customIf(isVector(0));

getActionDefinitionsBuilder(G_FCMP)
.legalFor({{s32, MinFPScalar},
{s32, s32},
.legalFor({{s32, s32},
{s32, s64},
{v4s32, v4s32},
{v2s32, v2s32},
{v2s64, v2s64}})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[1];
return (Ty == v8s16 || Ty == v4s16) && Ty == Query.Types[0] && HasFP16;
})
.legalFor(HasFP16, {{s32, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
.widenScalarOrEltToNextPow2(1)
.clampScalar(0, s32, s32)
.minScalarOrElt(1, MinFPScalar)
Expand Down Expand Up @@ -691,13 +661,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{v2s64, v2s64},
{v4s32, v4s32},
{v2s32, v2s32}})
.legalIf([=](const LegalityQuery &Query) {
return HasFP16 &&
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
Query.Types[1] == v8s16) &&
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
})
.legalFor(HasFP16,
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
// The range of a fp16 value fits into an i17, so we can lower the width
Expand Down Expand Up @@ -739,13 +704,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{v2s64, v2s64},
{v4s32, v4s32},
{v2s32, v2s32}})
.legalIf([=](const LegalityQuery &Query) {
return HasFP16 &&
(Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
Query.Types[1] == v8s16) &&
(Query.Types[0] == s32 || Query.Types[0] == s64 ||
Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
})
.legalFor(HasFP16,
{{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
// Handle types larger than i64 by scalarizing/lowering.
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
Expand Down Expand Up @@ -788,13 +748,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
{v2s64, v2s64},
{v4s32, v4s32},
{v2s32, v2s32}})
.legalIf([=](const LegalityQuery &Query) {
return HasFP16 &&
(Query.Types[0] == s16 || Query.Types[0] == v4s16 ||
Query.Types[0] == v8s16) &&
(Query.Types[1] == s32 || Query.Types[1] == s64 ||
Query.Types[1] == v4s16 || Query.Types[1] == v8s16);
})
.legalFor(HasFP16,
{{s16, s32}, {s16, s64}, {v4s16, v4s16}, {v8s16, v8s16}})
.scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
.scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
.moreElementsToNextPow2(1)
Expand Down Expand Up @@ -1048,12 +1003,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.widenScalarToNextPow2(1, /*Min=*/32)
.clampScalar(1, s32, s64)
.scalarSameSizeAs(0, 1)
.legalIf([=](const LegalityQuery &Query) {
return (HasCSSC && typeInSet(0, {s32, s64})(Query));
})
.customIf([=](const LegalityQuery &Query) {
return (!HasCSSC && typeInSet(0, {s32, s64})(Query));
});
.legalFor(HasCSSC, {s32, s64})
.customFor(!HasCSSC, {s32, s64});

getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
.legalIf([=](const LegalityQuery &Query) {
Expand Down Expand Up @@ -1141,11 +1092,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
}

// FIXME: Legal vector types are only legal with NEON.
auto &ABSActions = getActionDefinitionsBuilder(G_ABS);
if (HasCSSC)
ABSActions
.legalFor({s32, s64});
ABSActions.legalFor(PackedVectorAllTypeList)
getActionDefinitionsBuilder(G_ABS)
.legalFor(HasCSSC, {s32, s64})
.legalFor(PackedVectorAllTypeList)
.customIf([=](const LegalityQuery &Q) {
// TODO: Fix suboptimal codegen for 128+ bit types.
LLT SrcTy = Q.Types[0];
Expand All @@ -1169,10 +1118,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
// later.
getActionDefinitionsBuilder(G_VECREDUCE_FADD)
.legalFor({{s32, v2s32}, {s32, v4s32}, {s64, v2s64}})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[1];
return (Ty == v4s16 || Ty == v8s16) && HasFP16;
})
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
.minScalarOrElt(0, MinFPScalar)
.clampMaxNumElements(1, s64, 2)
.clampMaxNumElements(1, s32, 4)
Expand Down Expand Up @@ -1213,10 +1159,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
getActionDefinitionsBuilder({G_VECREDUCE_FMIN, G_VECREDUCE_FMAX,
G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM})
.legalFor({{s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
.legalIf([=](const LegalityQuery &Query) {
const auto &Ty = Query.Types[1];
return Query.Types[0] == s16 && (Ty == v8s16 || Ty == v4s16) && HasFP16;
})
.legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
.minScalarOrElt(0, MinFPScalar)
.clampMaxNumElements(1, s64, 2)
.clampMaxNumElements(1, s32, 4)
Expand Down Expand Up @@ -1293,32 +1236,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.customFor({{s32, s32}, {s64, s64}});

auto always = [=](const LegalityQuery &Q) { return true; };
auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
if (HasCSSC)
CTPOPActions
.legalFor({{s32, s32},
{s64, s64},
{v8s8, v8s8},
{v16s8, v16s8}})
.customFor({{s128, s128},
{v2s64, v2s64},
{v2s32, v2s32},
{v4s32, v4s32},
{v4s16, v4s16},
{v8s16, v8s16}});
else
CTPOPActions
.legalFor({{v8s8, v8s8},
{v16s8, v16s8}})
.customFor({{s32, s32},
{s64, s64},
{s128, s128},
{v2s64, v2s64},
{v2s32, v2s32},
{v4s32, v4s32},
{v4s16, v4s16},
{v8s16, v8s16}});
CTPOPActions
getActionDefinitionsBuilder(G_CTPOP)
.legalFor(HasCSSC, {{s32, s32}, {s64, s64}})
.legalFor({{v8s8, v8s8}, {v16s8, v16s8}})
.customFor(!HasCSSC, {{s32, s32}, {s64, s64}})
.customFor({{s128, s128},
{v2s64, v2s64},
{v2s32, v2s32},
{v4s32, v4s32},
{v4s16, v4s16},
{v8s16, v8s16}})
.clampScalar(0, s32, s128)
.widenScalarToNextPow2(0)
.minScalarEltSameAsIf(always, 1, 0)
Expand Down
Loading
Loading