
[SLP][REVEC] Fix the mismatch between the result of getAltInstrMask and the VecTy argument of TargetTransformInfo::isLegalAltInstr. #134795


Merged
merged 2 commits into llvm:main from HanKuanChen:slp-revec-getAltInstrMask on Apr 8, 2025

Conversation

HanKuanChen
Contributor

We cannot determine ScalarTy from VL inside getAltInstrMask, because some
callers determine ScalarTy from VL[0]->getType(), while others determine
it from getValueType(VL[0]).

Fix "Mask and VecTy are incompatible".

@llvmbot
Member

llvmbot commented Apr 8, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Han-Kuan Chen (HanKuanChen)

Changes

We cannot determine ScalarTy from VL inside getAltInstrMask, because some
callers determine ScalarTy from VL[0]->getType(), while others determine
it from getValueType(VL[0]).

Fix "Mask and VecTy are incompatible".


Full diff: https://github.com/llvm/llvm-project/pull/134795.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+13-11)
  • (added) llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll (+47)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e6559f26be8c2..7e167f238b82e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1264,9 +1264,8 @@ static void fixupOrderingIndices(MutableArrayRef<unsigned> Order) {
 
 /// \returns a bitset for selecting opcodes. false for Opcode0 and true for
 /// Opcode1.
-static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, unsigned Opcode0,
-                                      unsigned Opcode1) {
-  Type *ScalarTy = VL[0]->getType();
+static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
+                                      unsigned Opcode0, unsigned Opcode1) {
   unsigned ScalarTyNumElements = getNumElements(ScalarTy);
   SmallBitVector OpcodeMask(VL.size() * ScalarTyNumElements, false);
   for (unsigned Lane : seq<unsigned>(VL.size())) {
@@ -6667,11 +6666,12 @@ void BoUpSLP::reorderTopToBottom() {
     // to take into account their order when looking for the most used order.
     if (TE->hasState() && TE->isAltShuffle() &&
         TE->State != TreeEntry::SplitVectorize) {
-      VectorType *VecTy =
-          getWidenedType(TE->Scalars[0]->getType(), TE->Scalars.size());
+      Type *ScalarTy = TE->Scalars[0]->getType();
+      VectorType *VecTy = getWidenedType(ScalarTy, TE->Scalars.size());
       unsigned Opcode0 = TE->getOpcode();
       unsigned Opcode1 = TE->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(TE->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(TE->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
         VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get());
@@ -8352,12 +8352,13 @@ static bool isAlternateInstruction(const Instruction *I,
 
 bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
                                        ArrayRef<Value *> VL) const {
+  Type *ScalarTy = S.getMainOp()->getType();
   unsigned Opcode0 = S.getOpcode();
   unsigned Opcode1 = S.getAltOpcode();
-  SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+  SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
   // If this pattern is supported by the target then consider it profitable.
-  if (TTI->isLegalAltInstr(getWidenedType(S.getMainOp()->getType(), VL.size()),
-                           Opcode0, Opcode1, OpcodeMask))
+  if (TTI->isLegalAltInstr(getWidenedType(ScalarTy, VL.size()), Opcode0,
+                           Opcode1, OpcodeMask))
     return true;
   SmallVector<ValueList> Operands;
   for (unsigned I : seq<unsigned>(S.getMainOp()->getNumOperands())) {
@@ -9270,7 +9271,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
     unsigned Opcode0 = LocalState.getOpcode();
     unsigned Opcode1 = LocalState.getAltOpcode();
-    SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+    SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
     // Enable split node, only if all nodes do not form legal alternate
     // instruction (like X86 addsub).
     SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
@@ -13200,7 +13201,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       // order.
       unsigned Opcode0 = E->getOpcode();
       unsigned Opcode1 = E->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(E->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(E->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the
       // order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll b/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
new file mode 100644
index 0000000000000..8380b1cb5f850
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec %s | FileCheck %s
+
+define i32 @test() {
+; CHECK-LABEL: @test(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD136:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD137:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[WIDE_LOAD138:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i8> [[WIDE_LOAD]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i8> [[WIDE_LOAD136]], zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = or <16 x i8> [[WIDE_LOAD137]], zeroinitializer
+; CHECK-NEXT:    [[TMP5:%.*]] = or <16 x i8> [[WIDE_LOAD138]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ult <16 x i8> [[TMP2]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult <16 x i8> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp ult <16 x i8> [[TMP4]], zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp ult <16 x i8> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = or <16 x i8> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = or <16 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ult <16 x i8> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ult <16 x i8> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %wide.load = load <16 x i8>, ptr null, align 1
+  %wide.load136 = load <16 x i8>, ptr null, align 1
+  %wide.load137 = load <16 x i8>, ptr null, align 1
+  %wide.load138 = load <16 x i8>, ptr null, align 1
+  %0 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+  %1 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
+  %2 = or <16 x i8> %wide.load, zeroinitializer
+  %3 = or <16 x i8> %wide.load136, zeroinitializer
+  %4 = or <16 x i8> %wide.load137, zeroinitializer
+  %5 = or <16 x i8> %wide.load138, zeroinitializer
+  %6 = icmp ult <16 x i8> %2, zeroinitializer
+  %7 = icmp ult <16 x i8> %3, zeroinitializer
+  %8 = icmp ult <16 x i8> %4, zeroinitializer
+  %9 = icmp ult <16 x i8> %5, zeroinitializer
+  %10 = or <16 x i8> %0, zeroinitializer
+  %11 = or <16 x i8> %1, zeroinitializer
+  %12 = icmp ult <16 x i8> %10, zeroinitializer
+  %13 = icmp ult <16 x i8> %11, zeroinitializer
+  ret i32 0
+}
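
For completeness, the property each updated call site now guarantees, in the same toy model as above (again a sketch under the same assumptions, not the LLVM API): deriving both the widened type and the opcode mask from one ScalarTy makes the widths equal by construction.

#include <cassert>
#include <vector>

struct Ty { unsigned Lanes; };
unsigned getNumElements(Ty T) { return T.Lanes; }

int main() {
  unsigned VLSize = 6;
  Ty ScalarTy{16}; // the one ScalarTy, chosen by the caller
  // VecTy = getWidenedType(ScalarTy, VLSize) widens by the lane count:
  Ty VecTy{VLSize * getNumElements(ScalarTy)};
  // OpcodeMask = getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1):
  std::vector<bool> OpcodeMask(VLSize * getNumElements(ScalarTy), false);
  // "Mask and VecTy are incompatible" can no longer fire:
  assert(OpcodeMask.size() == getNumElements(VecTy));
  (void)OpcodeMask;
  (void)VecTy;
  return 0;
}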

@llvmbot
Member

llvmbot commented Apr 8, 2025

@llvm/pr-subscribers-vectorizers

Comment on lines +6669 to +6670
Type *ScalarTy = TE->Scalars[0]->getType();
VectorType *VecTy = getWidenedType(ScalarTy, TE->Scalars.size());
Member


Better not to hide ScalarTy, use another name for the var

Contributor Author


ScalarTy makes sense, doesn't it?

  1. It describes the type of TE->Scalars[0].
  2. ScalarTy hasn't appeared earlier.
  3. The scope of the if statement is small.

Member


Ah, missed that it is another function

HanKuanChen merged commit 2347aa1 into llvm:main on Apr 8, 2025
14 checks passed
HanKuanChen deleted the slp-revec-getAltInstrMask branch on April 8, 2025 at 14:29