Skip to content

[SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. #122241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 10, 2025

Conversation

HanKuanChen
Copy link
Contributor

@HanKuanChen HanKuanChen commented Jan 9, 2025

Replace Cnt and AltIndex with MainOp and AltOp.
Reduce the number of iterations in the for loop.

@llvmbot
Copy link
Member

llvmbot commented Jan 9, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Han-Kuan Chen (HanKuanChen)

Changes

Replace Cnt and AltIndex with MainOp and AltOp as InstructionsState.
Reduce the number of iterations in the for loop.


Full diff: https://github.com/llvm/llvm-project/pull/122241.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+28-32)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c4582df89213d8..aa12af43dcf95c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -915,24 +915,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   if (It == VL.end())
     return InstructionsState::invalid();
 
-  Value *V = *It;
+  Instruction *MainOp = cast<Instruction>(*It);
   unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
-  if ((VL.size() > 2 && !isa<PHINode>(V) && InstCnt < VL.size() / 2) ||
+  if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
       (VL.size() == 2 && InstCnt < 2))
     return InstructionsState::invalid();
 
-  bool IsCastOp = isa<CastInst>(V);
-  bool IsBinOp = isa<BinaryOperator>(V);
-  bool IsCmpOp = isa<CmpInst>(V);
-  CmpInst::Predicate BasePred =
-      IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
-  unsigned Opcode = cast<Instruction>(V)->getOpcode();
+  bool IsCastOp = isa<CastInst>(MainOp);
+  bool IsBinOp = isa<BinaryOperator>(MainOp);
+  bool IsCmpOp = isa<CmpInst>(MainOp);
+  CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
+                                        : CmpInst::BAD_ICMP_PREDICATE;
+  Instruction *AltOp = MainOp;
+  unsigned Opcode = MainOp->getOpcode();
   unsigned AltOpcode = Opcode;
-  unsigned AltIndex = std::distance(VL.begin(), It);
 
-  bool SwappedPredsCompatible = [&]() {
-    if (!IsCmpOp)
-      return false;
+  bool SwappedPredsCompatible = IsCmpOp && [&]() {
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
     UniquePreds.insert(BasePred);
     UniqueNonSwappedPreds.insert(BasePred);
@@ -955,18 +953,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   }();
   // Check for one alternate opcode from another BinaryOperator.
   // TODO - generalize to support all operators (types, calls etc.).
-  auto *IBase = cast<Instruction>(V);
   Intrinsic::ID BaseID = 0;
   SmallVector<VFInfo> BaseMappings;
-  if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
+  if (auto *CallBase = dyn_cast<CallInst>(MainOp)) {
     BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
     BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
     if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
       return InstructionsState::invalid();
   }
   bool AnyPoison = InstCnt != VL.size();
-  for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
-    auto *I = dyn_cast<Instruction>(VL[Cnt]);
+  // Skip MainOp.
+  while (++It != VL.end()) {
+    auto *I = dyn_cast<Instruction>(*It);
     if (!I)
       continue;
 
@@ -982,11 +980,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
           isValidForAlternation(Opcode)) {
         AltOpcode = InstOpcode;
-        AltIndex = Cnt;
+        AltOp = I;
         continue;
       }
     } else if (IsCastOp && isa<CastInst>(I)) {
-      Value *Op0 = IBase->getOperand(0);
+      Value *Op0 = MainOp->getOperand(0);
       Type *Ty0 = Op0->getType();
       Value *Op1 = I->getOperand(0);
       Type *Ty1 = Op1->getType();
@@ -998,12 +996,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                  isValidForAlternation(InstOpcode) &&
                  "Cast isn't safe for alternation, logic needs to be updated!");
           AltOpcode = InstOpcode;
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
       }
-    } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
-      auto *BaseInst = cast<CmpInst>(V);
+    } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
+      auto *BaseInst = cast<CmpInst>(MainOp);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
@@ -1017,24 +1015,23 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         CmpInst::Predicate SwappedCurrentPred =
             CmpInst::getSwappedPredicate(CurrentPred);
 
-        if ((E == 2 || SwappedPredsCompatible) &&
+        if ((VL.size() == 2 || SwappedPredsCompatible) &&
             (BasePred == CurrentPred || BasePred == SwappedCurrentPred))
           continue;
 
         if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
           continue;
-        auto *AltInst = cast<CmpInst>(VL[AltIndex]);
-        if (AltIndex) {
-          if (isCmpSameOrSwapped(AltInst, Inst, TLI))
+        if (MainOp != AltOp) {
+          if (isCmpSameOrSwapped(cast<CmpInst>(AltOp), Inst, TLI))
             continue;
         } else if (BasePred != CurrentPred) {
           assert(
               isValidForAlternation(InstOpcode) &&
               "CmpInst isn't safe for alternation, logic needs to be updated!");
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
-        CmpInst::Predicate AltPred = AltInst->getPredicate();
+        CmpInst::Predicate AltPred = cast<CmpInst>(AltOp)->getPredicate();
         if (BasePred == CurrentPred || BasePred == SwappedCurrentPred ||
             AltPred == CurrentPred || AltPred == SwappedCurrentPred)
           continue;
@@ -1045,17 +1042,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
              "CastInst.");
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
-            Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
+            Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
           return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
           return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
-        auto *BaseLI = cast<LoadInst>(IBase);
+        auto *BaseLI = cast<LoadInst>(MainOp);
         if (!LI->isSimple() || !BaseLI->isSimple())
           return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
-        auto *CallBase = cast<CallInst>(IBase);
+        auto *CallBase = cast<CallInst>(MainOp);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
           return InstructionsState::invalid();
         if (Call->hasOperandBundles() &&
@@ -1085,8 +1082,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
   }
 
-  return InstructionsState(cast<Instruction>(V),
-                           cast<Instruction>(VL[AltIndex]));
+  return InstructionsState(MainOp, cast<Instruction>(AltOp));
 }
 
 /// \returns true if all of the values in \p VL have the same type or false

Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@HanKuanChen HanKuanChen merged commit 36b423e into llvm:main Jan 10, 2025
6 of 8 checks passed
@HanKuanChen HanKuanChen deleted the slp-getSameOpcode branch January 10, 2025 01:06
@vitalybuka
Copy link
Collaborator

Looks like it's crashing bots https://lab.llvm.org/buildbot/#/builders/55/builds/5319
CC @fmayer

BaiXilin pushed a commit to BaiXilin/llvm-fix-vnni-instr-types that referenced this pull request Jan 12, 2025
…vm#122241)

Replace Cnt and AltIndex with MainOp and AltOp.
Reduce the number of iterations in the for loop.
@mikaelholmen
Copy link
Collaborator

Hi @HanKuanChen and @alexey-bataev ,

The following starts crashing with this "NFC" patch:

opt -passes="slp-vectorizer" bbi-103405.ll -S -o /dev/null -mtriple=thumb7 -mcpu=swift

bbi-103405.ll.gz

@alexey-bataev
Copy link
Member

Hi @HanKuanChen and @alexey-bataev ,

The following starts crashing with this "NFC" patch:

opt -passes="slp-vectorizer" bbi-103405.ll -S -o /dev/null -mtriple=thumb7 -mcpu=swift

bbi-103405.ll.gz

Fixed in a1ab5b4

@mikaelholmen
Copy link
Collaborator

Hi @HanKuanChen and @alexey-bataev ,
The following starts crashing with this "NFC" patch:

opt -passes="slp-vectorizer" bbi-103405.ll -S -o /dev/null -mtriple=thumb7 -mcpu=swift

bbi-103405.ll.gz

Fixed in a1ab5b4

I've verified that the crash goes away with the fix.
Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants