Skip to content

[MIPatternMatch] Add m_DeferredReg/Type #121218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 30, 2024

Conversation

mshockwave
Copy link
Member

This pattern does the same thing as m_SpecificReg/Type except the value it matches against origniated from an earlier pattern in the same mi_match expression.

This patch also changes how commutative patterns are handled: in order to support m_DefferedReg/Type, we always have to run the LHS-pattern before the RHS one.

This pattern does the same thing as m_SpecificReg/Type except the value
it matches against origniated from an earlier pattern in the same
mi_match expression.

This patch also changes how commutative patterns are handled: in order
to support m_DefferedReg/Type, we always have to run the LHS-pattern
before the RHS one.
@llvmbot
Copy link
Member

llvmbot commented Dec 27, 2024

@llvm/pr-subscribers-llvm-globalisel

Author: Min-Yih Hsu (mshockwave)

Changes

This pattern does the same thing as m_SpecificReg/Type except the value it matches against origniated from an earlier pattern in the same mi_match expression.

This patch also changes how commutative patterns are handled: in order to support m_DefferedReg/Type, we always have to run the LHS-pattern before the RHS one.


Full diff: https://github.com/llvm/llvm-project/pull/121218.diff

2 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h (+48-4)
  • (modified) llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp (+30)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index 47417f53b6e40a..450aaa17ee2665 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -372,6 +372,36 @@ inline bind_ty<LLT> m_Type(LLT &Ty) { return Ty; }
 inline bind_ty<CmpInst::Predicate> m_Pred(CmpInst::Predicate &P) { return P; }
 inline operand_type_match m_Pred() { return operand_type_match(); }
 
+template <typename BindTy> struct deferred_helper {
+  static bool match(const MachineRegisterInfo &MRI, BindTy &VR, BindTy &V) {
+    return VR == V;
+  }
+};
+
+template <> struct deferred_helper<LLT> {
+  static bool match(const MachineRegisterInfo &MRI, LLT VT, Register R) {
+    return VT == MRI.getType(R);
+  }
+};
+
+template <typename Class> struct deferred_ty {
+  Class &VR;
+
+  deferred_ty(Class &V) : VR(V) {}
+
+  template <typename ITy> bool match(const MachineRegisterInfo &MRI, ITy &&V) {
+    return deferred_helper<Class>::match(MRI, VR, V);
+  }
+};
+
+/// Similar to m_SpecificReg/Type, but the specific value to match originated
+/// from an earlier sub-pattern in the same mi_match expression. For example,
+/// we cannot match `(add X, X)` with `m_GAdd(m_Reg(X), m_SpecificReg(X))`
+/// because `X` is not initialized at the time it's passed to `m_SpecificReg`.
+/// Instead, we can use `m_GAdd(m_Reg(x), m_DeferredReg(X))`.
+inline deferred_ty<Register> m_DeferredReg(Register &R) { return R; }
+inline deferred_ty<LLT> m_DeferredType(LLT &Ty) { return Ty; }
+
 struct ImplicitDefMatch {
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     MachineInstr *TmpMI;
@@ -401,8 +431,13 @@ struct BinaryOp_match {
       if (TmpMI->getOpcode() == Opcode && TmpMI->getNumOperands() == 3) {
         return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
                 R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-               (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                               L.match(MRI, TmpMI->getOperand(2).getReg())));
+               // NOTE: When trying the alternative different operand ordering
+               // with a commutative operation, it is imperative to always run
+               // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+               // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
+               // expected.
+               (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                               R.match(MRI, TmpMI->getOperand(1).getReg())));
       }
     }
     return false;
@@ -426,8 +461,13 @@ struct BinaryOpc_match {
           TmpMI->getNumOperands() == 3) {
         return (L.match(MRI, TmpMI->getOperand(1).getReg()) &&
                 R.match(MRI, TmpMI->getOperand(2).getReg())) ||
-               (Commutable && (R.match(MRI, TmpMI->getOperand(1).getReg()) &&
-                               L.match(MRI, TmpMI->getOperand(2).getReg())));
+               // NOTE: When trying the alternative different operand ordering
+               // with a commutative operation, it is imperative to always run
+               // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+               // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as
+               // expected.
+               (Commutable && (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+                               R.match(MRI, TmpMI->getOperand(1).getReg())));
       }
     }
     return false;
@@ -674,6 +714,10 @@ struct CompareOp_match {
     Register RHS = TmpMI->getOperand(3).getReg();
     if (L.match(MRI, LHS) && R.match(MRI, RHS))
       return true;
+    // NOTE: When trying the alternative different operand ordering
+    // with a commutative operation, it is imperative to always run
+    // the LHS sub-pattern  (i.e. `L`) before the RHS sub-pattern
+    // (i.e. `R`). Otherwsie, m_DeferredReg/Type will not work as expected.
     if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) &&
         P.match(MRI, CmpInst::getSwappedPredicate(TmpPred)))
       return true;
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index fc76d4055722e4..40cd055c1c3f80 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -920,6 +920,36 @@ TEST_F(AArch64GISelMITest, MatchSpecificReg) {
   EXPECT_TRUE(mi_match(Add.getReg(0), *MRI, m_GAdd(m_SpecificReg(Reg), m_Reg())));
 }
 
+TEST_F(AArch64GISelMITest, DeferredMatching) {
+  setUp();
+  if (!TM)
+    GTEST_SKIP();
+  auto s64 = LLT::scalar(64);
+  auto s32 = LLT::scalar(32);
+
+  auto Cst1 = B.buildConstant(s64, 42);
+  auto Cst2 = B.buildConstant(s64, 314);
+  auto Add = B.buildAdd(s64, Cst1, Cst2);
+  auto Sub = B.buildSub(s64, Add, Cst1);
+
+  auto TruncAdd = B.buildTrunc(s32, Add);
+  auto TruncSub = B.buildTrunc(s32, Sub);
+  auto NarrowAdd = B.buildAdd(s32, TruncAdd, TruncSub);
+
+  Register X;
+  EXPECT_TRUE(mi_match(Sub.getReg(0), *MRI,
+                       m_GSub(m_GAdd(m_Reg(X), m_Reg()), m_DeferredReg(X))));
+  LLT Ty;
+  EXPECT_TRUE(
+      mi_match(NarrowAdd.getReg(0), *MRI,
+               m_GAdd(m_GTrunc(m_Type(Ty)), m_GTrunc(m_DeferredType(Ty)))));
+
+  // Test commutative.
+  auto Add2 = B.buildAdd(s64, Sub, Cst1);
+  EXPECT_TRUE(mi_match(Add2.getReg(0), *MRI,
+                       m_GAdd(m_Reg(X), m_GSub(m_Reg(), m_DeferredReg(X)))));
+}
+
 } // namespace
 
 int main(int argc, char **argv) {

@mshockwave mshockwave merged commit a74f825 into llvm:main Dec 30, 2024
8 checks passed
@mshockwave mshockwave deleted the patch/mipatternmatch-deferred branch December 30, 2024 17:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants