Skip to content

[GlobalISel] Add boolean predicated legalization action methods. #111287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 16, 2024

Conversation

davemgreen
Copy link
Collaborator

Under AArch64 it is common and will become more common to have operation legalization rules dependant on a feature of the architecture. For example HasFP16 or the newer CSSC interger min/max instructions, almong many others. With the current legalization rules this either means adding a custom predicate based on the feature as in
legalIf([=](const LegalityQuery &Query) { return HasFP16 && ...; } or splitting the legalization rules into pieces that place rules optionally into them base on the features available.

This patch proposes an alterative where the existing routines like legalFor(..) are provided a boolean predicate, which if false skips adding the rule. It makes the rules cleaner and will hopefully allow them to scale better as we add more features.

The SVE predicates for loads/stores I have changed to just be always available. Scalable vectors without SVE have never been supported, but it could also add a condition.

Under AArch64 it is common and will become more common to have operation
legalization rules dependant on a feature of the architecture. For example
HasFP16 or the newer CSSC interger min/max instructions, almong many others.
With the current legalization rules this either means adding a custom predicate
based on the feature as in
`legalIf([=](const LegalityQuery &Query) { return HasFP16 && ...; }`
or splitting the legalization rules into pieces that place rules optionally
into them base on the features available.

This patch proposes an alterative where the existing routines like legalFor(..)
are provided a boolean predicate, which if false skips adding the rule. It
makes the rules cleaner and will hopefully allow them to scale better as we add
more features.
@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2024

@llvm/pr-subscribers-llvm-globalisel

Author: David Green (davemgreen)

Changes

Under AArch64 it is common and will become more common to have operation legalization rules dependant on a feature of the architecture. For example HasFP16 or the newer CSSC interger min/max instructions, almong many others. With the current legalization rules this either means adding a custom predicate based on the feature as in
legalIf([=](const LegalityQuery &Query) { return HasFP16 && ...; } or splitting the legalization rules into pieces that place rules optionally into them base on the features available.

This patch proposes an alterative where the existing routines like legalFor(..) are provided a boolean predicate, which if false skips adding the rule. It makes the rules cleaner and will hopefully allow them to scale better as we add more features.

The SVE predicates for loads/stores I have changed to just be always available. Scalable vectors without SVE have never been supported, but it could also add a condition.


Patch is 25.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111287.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+29-2)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+54-127)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+36-36)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 82e713f30ea31c..4e5a6cf92b761d 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -599,11 +599,22 @@ class LegalizeRuleSet {
   LegalizeRuleSet &legalFor(std::initializer_list<LLT> Types) {
     return actionFor(LegalizeAction::Legal, Types);
   }
+  LegalizeRuleSet &legalFor(bool Pred, std::initializer_list<LLT> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Legal, Types);
+  }
   /// The instruction is legal when type indexes 0 and 1 is any type pair in the
   /// given list.
   LegalizeRuleSet &legalFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
     return actionFor(LegalizeAction::Legal, Types);
   }
+  LegalizeRuleSet &legalFor(bool Pred,
+                            std::initializer_list<std::pair<LLT, LLT>> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Legal, Types);
+  }
   /// The instruction is legal when type index 0 is any type in the given list
   /// and imm index 0 is anything.
   LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
@@ -846,12 +857,23 @@ class LegalizeRuleSet {
   LegalizeRuleSet &customFor(std::initializer_list<LLT> Types) {
     return actionFor(LegalizeAction::Custom, Types);
   }
+  LegalizeRuleSet &customFor(bool Pred, std::initializer_list<LLT> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Custom, Types);
+  }
 
-  /// The instruction is custom when type indexes 0 and 1 is any type pair in the
-  /// given list.
+  /// The instruction is custom when type indexes 0 and 1 is any type pair in
+  /// the given list.
   LegalizeRuleSet &customFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
     return actionFor(LegalizeAction::Custom, Types);
   }
+  LegalizeRuleSet &customFor(bool Pred,
+                             std::initializer_list<std::pair<LLT, LLT>> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Custom, Types);
+  }
 
   LegalizeRuleSet &customForCartesianProduct(std::initializer_list<LLT> Types) {
     return actionForCartesianProduct(LegalizeAction::Custom, Types);
@@ -990,6 +1012,11 @@ class LegalizeRuleSet {
                     scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
+  LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
+    if (!Pred)
+      return *this;
+    return minScalar(TypeIdx, Ty);
+  }
 
   /// Ensure the scalar is at least as wide as Ty if condition is met.
   LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 3d313ca00f1259..9b38b0f89a679f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -215,19 +215,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .legalFor({s64, v8s16, v16s8, v4s32})
       .lower();
 
-  auto &MinMaxActions = getActionDefinitionsBuilder(
-      {G_SMIN, G_SMAX, G_UMIN, G_UMAX});
-  if (HasCSSC)
-    MinMaxActions
-        .legalFor({s32, s64, v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
-        // Making clamping conditional on CSSC extension as without legal types we
-        // lower to CMP which can fold one of the two sxtb's we'd otherwise need
-        // if we detect a type smaller than 32-bit.
-        .minScalar(0, s32);
-  else
-    MinMaxActions
-        .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32});
-  MinMaxActions
+  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
+      .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
+      .legalFor(HasCSSC, {s32, s64})
+      .minScalar(HasCSSC, 0, s32)
       .clampNumElements(0, v8s8, v16s8)
       .clampNumElements(0, v4s16, v8s16)
       .clampNumElements(0, v2s32, v4s32)
@@ -247,11 +238,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       {G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FSQRT, G_FMAXNUM, G_FMINNUM,
        G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR, G_FRINT, G_FNEARBYINT,
        G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        return (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor({s32, s64, v2s32, v4s32, v2s64})
+      .legalFor(HasFP16, {s16, v4s16, v8s16})
       .libcallFor({s128})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .minScalarOrElt(0, MinFPScalar)
@@ -261,11 +249,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .moreElementsToNextPow2(0);
 
   getActionDefinitionsBuilder({G_FABS, G_FNEG})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        return (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor({s32, s64, v2s32, v4s32, v2s64})
+      .legalFor(HasFP16, {s16, v4s16, v8s16})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .lowerIf(scalarOrEltWiderThan(0, 64))
       .clampNumElements(0, v4s16, v8s16)
@@ -350,31 +335,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
-  auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
-  auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
-
-  if (ST.hasSVE()) {
-    LoadActions.legalForTypesWithMemDesc({
-        // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 8},
-        {nxv8s16, p0, nxv8s16, 8},
-        {nxv4s32, p0, nxv4s32, 8},
-        {nxv2s64, p0, nxv2s64, 8},
-    });
-
-    // TODO: Add nxv2p0. Consider bitcastIf.
-    //       See #92130
-    //       https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
-    StoreActions.legalForTypesWithMemDesc({
-        // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 8},
-        {nxv8s16, p0, nxv8s16, 8},
-        {nxv4s32, p0, nxv4s32, 8},
-        {nxv2s64, p0, nxv2s64, 8},
-    });
-  }
-
-  LoadActions
+  getActionDefinitionsBuilder(G_LOAD)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -399,6 +360,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // These extends are also legal
       .legalForTypesWithMemDesc(
           {{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
+      .legalForTypesWithMemDesc({
+          // SVE vscale x 128 bit base sizes
+          {nxv16s8, p0, nxv16s8, 8},
+          {nxv8s16, p0, nxv8s16, 8},
+          {nxv4s32, p0, nxv4s32, 8},
+          {nxv2s64, p0, nxv2s64, 8},
+      })
       .widenScalarToNextPow2(0, /* MinSize = */ 8)
       .clampMaxNumElements(0, s8, 16)
       .clampMaxNumElements(0, s16, 8)
@@ -424,7 +392,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(IsPtrVecPred)
       .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
 
-  StoreActions
+  getActionDefinitionsBuilder(G_STORE)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
@@ -444,6 +412,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
            {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
            {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
            {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
+      .legalForTypesWithMemDesc({
+          // SVE vscale x 128 bit base sizes
+          // TODO: Add nxv2p0. Consider bitcastIf.
+          //       See #92130
+          // https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
+          {nxv16s8, p0, nxv16s8, 8},
+          {nxv8s16, p0, nxv8s16, 8},
+          {nxv4s32, p0, nxv4s32, 8},
+          {nxv2s64, p0, nxv2s64, 8},
+      })
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -530,12 +508,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .widenScalarToNextPow2(0)
       .clampScalar(0, s8, s64);
   getActionDefinitionsBuilder(G_FCONSTANT)
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        if (HasFP16 && Ty == s16)
-          return true;
-        return Ty == s32 || Ty == s64 || Ty == s128;
-      })
+      .legalFor({s32, s64, s128})
+      .legalFor(HasFP16, {s16})
       .clampScalar(0, MinFPScalar, s128);
 
   // FIXME: fix moreElementsToNextPow2
@@ -567,16 +541,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(isVector(0));
 
   getActionDefinitionsBuilder(G_FCMP)
-      .legalFor({{s32, MinFPScalar},
-                 {s32, s32},
+      .legalFor({{s32, s32},
                  {s32, s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32},
                  {v2s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return (Ty == v8s16 || Ty == v4s16) && Ty == Query.Types[0] && HasFP16;
-      })
+      .legalFor(HasFP16, {{s32, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       .widenScalarOrEltToNextPow2(1)
       .clampScalar(0, s32, s32)
       .minScalarOrElt(1, MinFPScalar)
@@ -691,13 +661,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
-                Query.Types[1] == v8s16) &&
-               (Query.Types[0] == s32 || Query.Types[0] == s64 ||
-                Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
       // The range of a fp16 value fits into an i17, so we can lower the width
@@ -739,13 +704,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
-                Query.Types[1] == v8s16) &&
-               (Query.Types[0] == s32 || Query.Types[0] == s64 ||
-                Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       // Handle types larger than i64 by scalarizing/lowering.
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
@@ -788,13 +748,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[0] == s16 || Query.Types[0] == v4s16 ||
-                Query.Types[0] == v8s16) &&
-               (Query.Types[1] == s32 || Query.Types[1] == s64 ||
-                Query.Types[1] == v4s16 || Query.Types[1] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s16, s32}, {s16, s64}, {v4s16, v4s16}, {v8s16, v8s16}})
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .moreElementsToNextPow2(1)
@@ -1048,12 +1003,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .widenScalarToNextPow2(1, /*Min=*/32)
       .clampScalar(1, s32, s64)
       .scalarSameSizeAs(0, 1)
-      .legalIf([=](const LegalityQuery &Query) {
-        return (HasCSSC && typeInSet(0, {s32, s64})(Query));
-      })
-      .customIf([=](const LegalityQuery &Query) {
-        return (!HasCSSC && typeInSet(0, {s32, s64})(Query));
-      });
+      .legalFor(HasCSSC, {s32, s64})
+      .customFor(!HasCSSC, {s32, s64});
 
   getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
       .legalIf([=](const LegalityQuery &Query) {
@@ -1141,11 +1092,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   }
 
   // FIXME: Legal vector types are only legal with NEON.
-  auto &ABSActions = getActionDefinitionsBuilder(G_ABS);
-  if (HasCSSC)
-    ABSActions
-        .legalFor({s32, s64});
-  ABSActions.legalFor(PackedVectorAllTypeList)
+  getActionDefinitionsBuilder(G_ABS)
+      .legalFor(HasCSSC, {s32, s64})
+      .legalFor(PackedVectorAllTypeList)
       .customIf([=](const LegalityQuery &Q) {
         // TODO: Fix suboptimal codegen for 128+ bit types.
         LLT SrcTy = Q.Types[0];
@@ -1169,10 +1118,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   // later.
   getActionDefinitionsBuilder(G_VECREDUCE_FADD)
       .legalFor({{s32, v2s32}, {s32, v4s32}, {s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return (Ty == v4s16 || Ty == v8s16) && HasFP16;
-      })
+      .legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
       .minScalarOrElt(0, MinFPScalar)
       .clampMaxNumElements(1, s64, 2)
       .clampMaxNumElements(1, s32, 4)
@@ -1213,10 +1159,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   getActionDefinitionsBuilder({G_VECREDUCE_FMIN, G_VECREDUCE_FMAX,
                                G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM})
       .legalFor({{s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return Query.Types[0] == s16 && (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
       .minScalarOrElt(0, MinFPScalar)
       .clampMaxNumElements(1, s64, 2)
       .clampMaxNumElements(1, s32, 4)
@@ -1293,32 +1236,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customFor({{s32, s32}, {s64, s64}});
 
   auto always = [=](const LegalityQuery &Q) { return true; };
-  auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
-  if (HasCSSC)
-    CTPOPActions
-        .legalFor({{s32, s32},
-                   {s64, s64},
-                   {v8s8, v8s8},
-                   {v16s8, v16s8}})
-        .customFor({{s128, s128},
-                    {v2s64, v2s64},
-                    {v2s32, v2s32},
-                    {v4s32, v4s32},
-                    {v4s16, v4s16},
-                    {v8s16, v8s16}});
-  else
-    CTPOPActions
-        .legalFor({{v8s8, v8s8},
-                   {v16s8, v16s8}})
-        .customFor({{s32, s32},
-                    {s64, s64},
-                    {s128, s128},
-                    {v2s64, v2s64},
-                    {v2s32, v2s32},
-                    {v4s32, v4s32},
-                    {v4s16, v4s16},
-                    {v8s16, v8s16}});
-  CTPOPActions
+  getActionDefinitionsBuilder(G_CTPOP)
+      .legalFor(HasCSSC, {{s32, s32}, {s64, s64}})
+      .legalFor({{v8s8, v8s8}, {v16s8, v16s8}})
+      .customFor(!HasCSSC, {{s32, s32}, {s64, s64}})
+      .customFor({{s128, s128},
+                  {v2s64, v2s64},
+                  {v2s32, v2s32},
+                  {v4s32, v4s32},
+                  {v4s16, v4s16},
+                  {v8s16, v8s16}})
       .clampScalar(0, s32, s128)
       .widenScalarToNextPow2(0)
       .minScalarEltSameAsIf(always, 1, 0)
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index a21b786a2bae97..073c3cafa062d3 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -152,12 +152,12 @@
 #
 # DEBUG-NEXT: G_INTRINSIC_TRUNC (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_ROUND (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_LRINT (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. the first uncovered type index: 2, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
@@ -167,8 +167,8 @@
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_ROUNDEVEN (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_READCYCLECOUNTER (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
@@ -310,8 +310,8 @@
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FCONSTANT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_VASTART (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
@@ -459,27 +459,27 @@
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: G_FADD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FSUB (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FMUL (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FMA (opcode {{[0-9]+}}):...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2024

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes

Under AArch64 it is common and will become more common to have operation legalization rules dependant on a feature of the architecture. For example HasFP16 or the newer CSSC interger min/max instructions, almong many others. With the current legalization rules this either means adding a custom predicate based on the feature as in
legalIf([=](const LegalityQuery &amp;Query) { return HasFP16 &amp;&amp; ...; } or splitting the legalization rules into pieces that place rules optionally into them base on the features available.

This patch proposes an alterative where the existing routines like legalFor(..) are provided a boolean predicate, which if false skips adding the rule. It makes the rules cleaner and will hopefully allow them to scale better as we add more features.

The SVE predicates for loads/stores I have changed to just be always available. Scalable vectors without SVE have never been supported, but it could also add a condition.


Patch is 25.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111287.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+29-2)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+54-127)
  • (modified) llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir (+36-36)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 82e713f30ea31c..4e5a6cf92b761d 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -599,11 +599,22 @@ class LegalizeRuleSet {
   LegalizeRuleSet &legalFor(std::initializer_list<LLT> Types) {
     return actionFor(LegalizeAction::Legal, Types);
   }
+  LegalizeRuleSet &legalFor(bool Pred, std::initializer_list<LLT> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Legal, Types);
+  }
   /// The instruction is legal when type indexes 0 and 1 is any type pair in the
   /// given list.
   LegalizeRuleSet &legalFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
     return actionFor(LegalizeAction::Legal, Types);
   }
+  LegalizeRuleSet &legalFor(bool Pred,
+                            std::initializer_list<std::pair<LLT, LLT>> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Legal, Types);
+  }
   /// The instruction is legal when type index 0 is any type in the given list
   /// and imm index 0 is anything.
   LegalizeRuleSet &legalForTypeWithAnyImm(std::initializer_list<LLT> Types) {
@@ -846,12 +857,23 @@ class LegalizeRuleSet {
   LegalizeRuleSet &customFor(std::initializer_list<LLT> Types) {
     return actionFor(LegalizeAction::Custom, Types);
   }
+  LegalizeRuleSet &customFor(bool Pred, std::initializer_list<LLT> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Custom, Types);
+  }
 
-  /// The instruction is custom when type indexes 0 and 1 is any type pair in the
-  /// given list.
+  /// The instruction is custom when type indexes 0 and 1 is any type pair in
+  /// the given list.
   LegalizeRuleSet &customFor(std::initializer_list<std::pair<LLT, LLT>> Types) {
     return actionFor(LegalizeAction::Custom, Types);
   }
+  LegalizeRuleSet &customFor(bool Pred,
+                             std::initializer_list<std::pair<LLT, LLT>> Types) {
+    if (!Pred)
+      return *this;
+    return actionFor(LegalizeAction::Custom, Types);
+  }
 
   LegalizeRuleSet &customForCartesianProduct(std::initializer_list<LLT> Types) {
     return actionForCartesianProduct(LegalizeAction::Custom, Types);
@@ -990,6 +1012,11 @@ class LegalizeRuleSet {
                     scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
+  LegalizeRuleSet &minScalar(bool Pred, unsigned TypeIdx, const LLT Ty) {
+    if (!Pred)
+      return *this;
+    return minScalar(TypeIdx, Ty);
+  }
 
   /// Ensure the scalar is at least as wide as Ty if condition is met.
   LegalizeRuleSet &minScalarIf(LegalityPredicate Predicate, unsigned TypeIdx,
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 3d313ca00f1259..9b38b0f89a679f 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -215,19 +215,10 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .legalFor({s64, v8s16, v16s8, v4s32})
       .lower();
 
-  auto &MinMaxActions = getActionDefinitionsBuilder(
-      {G_SMIN, G_SMAX, G_UMIN, G_UMAX});
-  if (HasCSSC)
-    MinMaxActions
-        .legalFor({s32, s64, v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
-        // Making clamping conditional on CSSC extension as without legal types we
-        // lower to CMP which can fold one of the two sxtb's we'd otherwise need
-        // if we detect a type smaller than 32-bit.
-        .minScalar(0, s32);
-  else
-    MinMaxActions
-        .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32});
-  MinMaxActions
+  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX})
+      .legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
+      .legalFor(HasCSSC, {s32, s64})
+      .minScalar(HasCSSC, 0, s32)
       .clampNumElements(0, v8s8, v16s8)
       .clampNumElements(0, v4s16, v8s16)
       .clampNumElements(0, v2s32, v4s32)
@@ -247,11 +238,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       {G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FSQRT, G_FMAXNUM, G_FMINNUM,
        G_FMAXIMUM, G_FMINIMUM, G_FCEIL, G_FFLOOR, G_FRINT, G_FNEARBYINT,
        G_INTRINSIC_TRUNC, G_INTRINSIC_ROUND, G_INTRINSIC_ROUNDEVEN})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        return (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor({s32, s64, v2s32, v4s32, v2s64})
+      .legalFor(HasFP16, {s16, v4s16, v8s16})
       .libcallFor({s128})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .minScalarOrElt(0, MinFPScalar)
@@ -261,11 +249,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .moreElementsToNextPow2(0);
 
   getActionDefinitionsBuilder({G_FABS, G_FNEG})
-      .legalFor({MinFPScalar, s32, s64, v2s32, v4s32, v2s64})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        return (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor({s32, s64, v2s32, v4s32, v2s64})
+      .legalFor(HasFP16, {s16, v4s16, v8s16})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .lowerIf(scalarOrEltWiderThan(0, 64))
       .clampNumElements(0, v4s16, v8s16)
@@ -350,31 +335,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
     return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
-  auto &LoadActions = getActionDefinitionsBuilder(G_LOAD);
-  auto &StoreActions = getActionDefinitionsBuilder(G_STORE);
-
-  if (ST.hasSVE()) {
-    LoadActions.legalForTypesWithMemDesc({
-        // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 8},
-        {nxv8s16, p0, nxv8s16, 8},
-        {nxv4s32, p0, nxv4s32, 8},
-        {nxv2s64, p0, nxv2s64, 8},
-    });
-
-    // TODO: Add nxv2p0. Consider bitcastIf.
-    //       See #92130
-    //       https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
-    StoreActions.legalForTypesWithMemDesc({
-        // 128 bit base sizes
-        {nxv16s8, p0, nxv16s8, 8},
-        {nxv8s16, p0, nxv8s16, 8},
-        {nxv4s32, p0, nxv4s32, 8},
-        {nxv2s64, p0, nxv2s64, 8},
-    });
-  }
-
-  LoadActions
+  getActionDefinitionsBuilder(G_LOAD)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Acquire;
@@ -399,6 +360,13 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       // These extends are also legal
       .legalForTypesWithMemDesc(
           {{s32, p0, s8, 8}, {s32, p0, s16, 8}, {s64, p0, s32, 8}})
+      .legalForTypesWithMemDesc({
+          // SVE vscale x 128 bit base sizes
+          {nxv16s8, p0, nxv16s8, 8},
+          {nxv8s16, p0, nxv8s16, 8},
+          {nxv4s32, p0, nxv4s32, 8},
+          {nxv2s64, p0, nxv2s64, 8},
+      })
       .widenScalarToNextPow2(0, /* MinSize = */ 8)
       .clampMaxNumElements(0, s8, 16)
       .clampMaxNumElements(0, s16, 8)
@@ -424,7 +392,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(IsPtrVecPred)
       .scalarizeIf(typeInSet(0, {v2s16, v2s8}), 0);
 
-  StoreActions
+  getActionDefinitionsBuilder(G_STORE)
       .customIf([=](const LegalityQuery &Query) {
         return HasRCPC3 && Query.Types[0] == s128 &&
                Query.MMODescrs[0].Ordering == AtomicOrdering::Release;
@@ -444,6 +412,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
            {p0, p0, s64, 8},    {s128, p0, s128, 8},  {v16s8, p0, s128, 8},
            {v8s8, p0, s64, 8},  {v4s16, p0, s64, 8},  {v8s16, p0, s128, 8},
            {v2s32, p0, s64, 8}, {v4s32, p0, s128, 8}, {v2s64, p0, s128, 8}})
+      .legalForTypesWithMemDesc({
+          // SVE vscale x 128 bit base sizes
+          // TODO: Add nxv2p0. Consider bitcastIf.
+          //       See #92130
+          // https://github.com/llvm/llvm-project/pull/92130#discussion_r1616888461
+          {nxv16s8, p0, nxv16s8, 8},
+          {nxv8s16, p0, nxv8s16, 8},
+          {nxv4s32, p0, nxv4s32, 8},
+          {nxv2s64, p0, nxv2s64, 8},
+      })
       .clampScalar(0, s8, s64)
       .lowerIf([=](const LegalityQuery &Query) {
         return Query.Types[0].isScalar() &&
@@ -530,12 +508,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .widenScalarToNextPow2(0)
       .clampScalar(0, s8, s64);
   getActionDefinitionsBuilder(G_FCONSTANT)
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[0];
-        if (HasFP16 && Ty == s16)
-          return true;
-        return Ty == s32 || Ty == s64 || Ty == s128;
-      })
+      .legalFor({s32, s64, s128})
+      .legalFor(HasFP16, {s16})
       .clampScalar(0, MinFPScalar, s128);
 
   // FIXME: fix moreElementsToNextPow2
@@ -567,16 +541,12 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customIf(isVector(0));
 
   getActionDefinitionsBuilder(G_FCMP)
-      .legalFor({{s32, MinFPScalar},
-                 {s32, s32},
+      .legalFor({{s32, s32},
                  {s32, s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32},
                  {v2s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return (Ty == v8s16 || Ty == v4s16) && Ty == Query.Types[0] && HasFP16;
-      })
+      .legalFor(HasFP16, {{s32, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       .widenScalarOrEltToNextPow2(1)
       .clampScalar(0, s32, s32)
       .minScalarOrElt(1, MinFPScalar)
@@ -691,13 +661,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
-                Query.Types[1] == v8s16) &&
-               (Query.Types[0] == s32 || Query.Types[0] == s64 ||
-                Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
       // The range of a fp16 value fits into an i17, so we can lower the width
@@ -739,13 +704,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[1] == s16 || Query.Types[1] == v4s16 ||
-                Query.Types[1] == v8s16) &&
-               (Query.Types[0] == s32 || Query.Types[0] == s64 ||
-                Query.Types[0] == v4s16 || Query.Types[0] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s32, s16}, {s64, s16}, {v4s16, v4s16}, {v8s16, v8s16}})
       // Handle types larger than i64 by scalarizing/lowering.
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
@@ -788,13 +748,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
                  {v2s64, v2s64},
                  {v4s32, v4s32},
                  {v2s32, v2s32}})
-      .legalIf([=](const LegalityQuery &Query) {
-        return HasFP16 &&
-               (Query.Types[0] == s16 || Query.Types[0] == v4s16 ||
-                Query.Types[0] == v8s16) &&
-               (Query.Types[1] == s32 || Query.Types[1] == s64 ||
-                Query.Types[1] == v4s16 || Query.Types[1] == v8s16);
-      })
+      .legalFor(HasFP16,
+                {{s16, s32}, {s16, s64}, {v4s16, v4s16}, {v8s16, v8s16}})
       .scalarizeIf(scalarOrEltWiderThan(1, 64), 1)
       .scalarizeIf(scalarOrEltWiderThan(0, 64), 0)
       .moreElementsToNextPow2(1)
@@ -1048,12 +1003,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .widenScalarToNextPow2(1, /*Min=*/32)
       .clampScalar(1, s32, s64)
       .scalarSameSizeAs(0, 1)
-      .legalIf([=](const LegalityQuery &Query) {
-        return (HasCSSC && typeInSet(0, {s32, s64})(Query));
-      })
-      .customIf([=](const LegalityQuery &Query) {
-        return (!HasCSSC && typeInSet(0, {s32, s64})(Query));
-      });
+      .legalFor(HasCSSC, {s32, s64})
+      .customFor(!HasCSSC, {s32, s64});
 
   getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
       .legalIf([=](const LegalityQuery &Query) {
@@ -1141,11 +1092,9 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   }
 
   // FIXME: Legal vector types are only legal with NEON.
-  auto &ABSActions = getActionDefinitionsBuilder(G_ABS);
-  if (HasCSSC)
-    ABSActions
-        .legalFor({s32, s64});
-  ABSActions.legalFor(PackedVectorAllTypeList)
+  getActionDefinitionsBuilder(G_ABS)
+      .legalFor(HasCSSC, {s32, s64})
+      .legalFor(PackedVectorAllTypeList)
       .customIf([=](const LegalityQuery &Q) {
         // TODO: Fix suboptimal codegen for 128+ bit types.
         LLT SrcTy = Q.Types[0];
@@ -1169,10 +1118,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   // later.
   getActionDefinitionsBuilder(G_VECREDUCE_FADD)
       .legalFor({{s32, v2s32}, {s32, v4s32}, {s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return (Ty == v4s16 || Ty == v8s16) && HasFP16;
-      })
+      .legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
       .minScalarOrElt(0, MinFPScalar)
       .clampMaxNumElements(1, s64, 2)
       .clampMaxNumElements(1, s32, 4)
@@ -1213,10 +1159,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   getActionDefinitionsBuilder({G_VECREDUCE_FMIN, G_VECREDUCE_FMAX,
                                G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM})
       .legalFor({{s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
-      .legalIf([=](const LegalityQuery &Query) {
-        const auto &Ty = Query.Types[1];
-        return Query.Types[0] == s16 && (Ty == v8s16 || Ty == v4s16) && HasFP16;
-      })
+      .legalFor(HasFP16, {{s16, v4s16}, {s16, v8s16}})
       .minScalarOrElt(0, MinFPScalar)
       .clampMaxNumElements(1, s64, 2)
       .clampMaxNumElements(1, s32, 4)
@@ -1293,32 +1236,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .customFor({{s32, s32}, {s64, s64}});
 
   auto always = [=](const LegalityQuery &Q) { return true; };
-  auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
-  if (HasCSSC)
-    CTPOPActions
-        .legalFor({{s32, s32},
-                   {s64, s64},
-                   {v8s8, v8s8},
-                   {v16s8, v16s8}})
-        .customFor({{s128, s128},
-                    {v2s64, v2s64},
-                    {v2s32, v2s32},
-                    {v4s32, v4s32},
-                    {v4s16, v4s16},
-                    {v8s16, v8s16}});
-  else
-    CTPOPActions
-        .legalFor({{v8s8, v8s8},
-                   {v16s8, v16s8}})
-        .customFor({{s32, s32},
-                    {s64, s64},
-                    {s128, s128},
-                    {v2s64, v2s64},
-                    {v2s32, v2s32},
-                    {v4s32, v4s32},
-                    {v4s16, v4s16},
-                    {v8s16, v8s16}});
-  CTPOPActions
+  getActionDefinitionsBuilder(G_CTPOP)
+      .legalFor(HasCSSC, {{s32, s32}, {s64, s64}})
+      .legalFor({{v8s8, v8s8}, {v16s8, v16s8}})
+      .customFor(!HasCSSC, {{s32, s32}, {s64, s64}})
+      .customFor({{s128, s128},
+                  {v2s64, v2s64},
+                  {v2s32, v2s32},
+                  {v4s32, v4s32},
+                  {v4s16, v4s16},
+                  {v8s16, v8s16}})
       .clampScalar(0, s32, s128)
       .widenScalarToNextPow2(0)
       .minScalarEltSameAsIf(always, 1, 0)
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
index a21b786a2bae97..073c3cafa062d3 100644
--- a/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir
@@ -152,12 +152,12 @@
 #
 # DEBUG-NEXT: G_INTRINSIC_TRUNC (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_ROUND (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_LRINT (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. the first uncovered type index: 2, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
@@ -167,8 +167,8 @@
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_INTRINSIC_ROUNDEVEN (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_READCYCLECOUNTER (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
@@ -310,8 +310,8 @@
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FCONSTANT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_VASTART (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. the first uncovered type index: 1, OK
 # DEBUG-NEXT: .. the first uncovered imm index: 0, OK
@@ -459,27 +459,27 @@
 # DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
 # DEBUG-NEXT: G_FADD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FSUB (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FMUL (opcode {{[0-9]+}}): 1 type index, 0 imm indices
 # DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
-# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
-# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. the first uncovered type index: 1, OK
+# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
 # DEBUG-NEXT: G_FMA (opcode {{[0-9]+}}):...
[truncated]

@tschuett
Copy link

tschuett commented Oct 6, 2024

For the new legalFor and customFor, I have a slight preference for replacing For with If . Then it is clear that the first parameter is a predicate and not a type.

@aemerson
Copy link
Contributor

aemerson commented Oct 7, 2024

For the new legalFor and customFor, I have a slight preference for replacing For with If . Then it is clear that the first parameter is a predicate and not a type.

We already have a predicated legalIf which takes a legality predicate.

I'm not sure of this change. I get the motivation but I'm wondering if we have a better way to express it than having to plumb in a boolean into specific rules. Perhaps a guard wrapper or something so we can write:

AddIf(Cond, legalFor({s32, s64}))

and that way it generalizes to all of the rule methods.

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have a legalIf, but that should be for complex predicates. This has the advantage of not actually adding it to the rule set for the subtarget

@Him188
Copy link
Member

Him188 commented Oct 11, 2024

Adding one more idea - I attempted to propose a predicate hasFP16() to be used in combination of all. So a typical usage looks like:

.legalIf(all(hasFP16(), typeInSet(0, {v8s16, v4s16})))

Which sounds like @aemerson 's idea but does not require a new AddIf.

See
#104753 (comment)

@tschuett
Copy link

tschuett commented Oct 11, 2024

Store common legality queries at the top of the file:

auto HasSVEAndTypes = [=](const LegalityQuery &Query) {
   return HasSVE && typeSetIn(0, {nxv16s8, nxv8s16, nxv4s32, nxv2s64});
},

used them in legalIf

.legalIf(HasSVEAndTypes)

@davemgreen
Copy link
Collaborator Author

For the new legalFor and customFor, I have a slight preference for replacing For with If . Then it is clear that the first parameter is a predicate and not a type.

We already have a predicated legalIf which takes a legality predicate.

I'm not sure of this change. I get the motivation but I'm wondering if we have a better way to express it than having to plumb in a boolean into specific rules. Perhaps a guard wrapper or something so we can write:

AddIf(Cond, legalFor({s32, s64}))

and that way it generalizes to all of the rule methods.

The was one thing I thought of, yeah. But I'm not sure how it would work. Do you know how that would work in C++, how would it pass the legalFor method through to the AddIf, it they are both part of the LegalizeRuleSet?

Adding one more idea - I attempted to propose a predicate hasFP16() to be used in combination of all. So a typical usage looks like:

.legalIf(all(hasFP16(), typeInSet(0, {v8s16, v4s16})))

Which sounds like @aemerson 's idea but does not require a new AddIf.

See #104753 (comment)

Yep that was what I was trying to avoid. I wrote this back when that patch was created (hence my suggestion to remove it), but rewrote it again recently. IMO it will get very ugly if we have to use the low-level interface for everything, and this is a common enough issue that having a specific interface for it is probably worth while. It should also at least in theory be a little quicker if it doesn't need to store the hasFP16 predicate and keep checking it again and again.

@aemerson
Copy link
Contributor

You're right, I thought we could find some C++ magic to make this ergonomic but after trying myself I came up short. At least not without a restructuring of the LegalizeRuleSet infrastructure.

Comment on lines +219 to +220
.legalFor({v8s8, v16s8, v4s16, v8s16, v2s32, v4s32})
.legalFor(HasCSSC, {s32, s64})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this could fuse into one rule with all of the types in the list, but I doubt there's a way to make that pretty

@davemgreen davemgreen merged commit a3010c7 into llvm:main Oct 16, 2024
12 checks passed
@davemgreen davemgreen deleted the gh-gi-boolpreds branch October 16, 2024 13:43
davemgreen added a commit that referenced this pull request Oct 16, 2024
Similar to #111287, this moves the UseOutlineAtomics legalization rules to a
boolean predicate as opposed to needing the be nested functions.

There appeared to be a pair of redundant customIfs for s128 sizes (assuming
only scalars are supported).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants