[SLP][REVEC] Fix the mismatch between the result of getAltInstrMask and the VecTy argument of TargetTransformInfo::isLegalAltInstr. #134795

Merged: 2 commits, Apr 8, 2025
24 changes: 13 additions & 11 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1264,9 +1264,8 @@ static void fixupOrderingIndices(MutableArrayRef<unsigned> Order) {
 
 /// \returns a bitset for selecting opcodes. false for Opcode0 and true for
 /// Opcode1.
-static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, unsigned Opcode0,
-                                      unsigned Opcode1) {
-  Type *ScalarTy = VL[0]->getType();
+static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
+                                      unsigned Opcode0, unsigned Opcode1) {
   unsigned ScalarTyNumElements = getNumElements(ScalarTy);
   SmallBitVector OpcodeMask(VL.size() * ScalarTyNumElements, false);
   for (unsigned Lane : seq<unsigned>(VL.size())) {
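
The body of this helper is truncated in the view above. For context, here is a sketch of the updated function with the elided loop body approximated rather than copied verbatim: each lane contributes getNumElements(ScalarTy) mask bits, and all bits of a lane are set when that lane's instruction uses Opcode1.

```cpp
// Sketch only: the loop body is an approximation of the truncated code, not a
// verbatim copy from SLPVectorizer.cpp.
static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
                                      unsigned Opcode0, unsigned Opcode1) {
  unsigned ScalarTyNumElements = getNumElements(ScalarTy);
  // One mask bit per element of the widened vector type, not per value in VL.
  SmallBitVector OpcodeMask(VL.size() * ScalarTyNumElements, false);
  for (unsigned Lane : seq<unsigned>(VL.size())) {
    if (isa<PoisonValue>(VL[Lane]))
      continue;
    // Mark the whole sub-range of this lane when it uses the alternate opcode.
    if (cast<Instruction>(VL[Lane])->getOpcode() == Opcode1)
      OpcodeMask.set(Lane * ScalarTyNumElements,
                     Lane * ScalarTyNumElements + ScalarTyNumElements);
  }
  return OpcodeMask;
}
```
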
@@ -6667,11 +6666,12 @@ void BoUpSLP::reorderTopToBottom() {
     // to take into account their order when looking for the most used order.
     if (TE->hasState() && TE->isAltShuffle() &&
         TE->State != TreeEntry::SplitVectorize) {
-      VectorType *VecTy =
-          getWidenedType(TE->Scalars[0]->getType(), TE->Scalars.size());
+      Type *ScalarTy = TE->Scalars[0]->getType();
+      VectorType *VecTy = getWidenedType(ScalarTy, TE->Scalars.size());
Comment on lines +6669 to +6670

Member

Better not to hide ScalarTy, use another name for the var

Contributor Author

ScalarTy makes sense, doesn't it?

  1. It describes the type of TE->Scalars[0].
  2. ScalarTy hasn't appeared earlier.
  3. The scope of the if statement is small.

Member

Ah, missed that it is another function

       unsigned Opcode0 = TE->getOpcode();
       unsigned Opcode1 = TE->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(TE->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(TE->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
         VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get());
@@ -8352,12 +8352,13 @@ static bool isAlternateInstruction(const Instruction *I,
 
 bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
                                        ArrayRef<Value *> VL) const {
+  Type *ScalarTy = S.getMainOp()->getType();
   unsigned Opcode0 = S.getOpcode();
   unsigned Opcode1 = S.getAltOpcode();
-  SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+  SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
   // If this pattern is supported by the target then consider it profitable.
-  if (TTI->isLegalAltInstr(getWidenedType(S.getMainOp()->getType(), VL.size()),
-                           Opcode0, Opcode1, OpcodeMask))
+  if (TTI->isLegalAltInstr(getWidenedType(ScalarTy, VL.size()), Opcode0,
+                           Opcode1, OpcodeMask))
     return true;
   SmallVector<ValueList> Operands;
   for (unsigned I : seq<unsigned>(S.getMainOp()->getNumOperands())) {
@@ -9270,7 +9271,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
     unsigned Opcode0 = LocalState.getOpcode();
     unsigned Opcode1 = LocalState.getAltOpcode();
-    SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
+    SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
     // Enable split node, only if all nodes do not form legal alternate
     // instruction (like X86 addsub).
     SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
@@ -13200,7 +13201,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       // order.
       unsigned Opcode0 = E->getOpcode();
       unsigned Opcode1 = E->getAltOpcode();
-      SmallBitVector OpcodeMask(getAltInstrMask(E->Scalars, Opcode0, Opcode1));
+      SmallBitVector OpcodeMask(
+          getAltInstrMask(E->Scalars, ScalarTy, Opcode0, Opcode1));
       // If this pattern is supported by the target then we consider the
       // order.
       if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
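
With these changes, every call site derives both the widened VecTy and the opcode mask from the same ScalarTy. A small standalone C++ illustration of the sizing follows; the concrete widths are assumptions chosen to mirror the revec test below, where each SLP "scalar" is itself a <16 x i8> vector, and this is not LLVM code.

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Under -slp-revec, each "scalar" in VL is itself a vector, e.g. <16 x i8>.
  const unsigned NumLanes = 4;              // VL.size()
  const unsigned ScalarTyNumElements = 16;  // getNumElements(ScalarTy) for <16 x i8>

  // VecTy passed to isLegalAltInstr: getWidenedType(ScalarTy, VL.size()) = <64 x i8>.
  const unsigned VecTyNumElements = NumLanes * ScalarTyNumElements;

  // getAltInstrMask now receives the same ScalarTy, so its bitset carries one
  // bit per VecTy element. Previously the helper derived the type on its own
  // from VL[0], which could differ from the ScalarTy used to build VecTy.
  const unsigned OpcodeMaskBits = NumLanes * ScalarTyNumElements;

  assert(OpcodeMaskBits == VecTyNumElements && "mask size must match VecTy");
  std::printf("VecTy elements = %u, OpcodeMask bits = %u\n", VecTyNumElements,
              OpcodeMaskBits);
  return 0;
}
```
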
47 changes: 47 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
@@ -0,0 +1,47 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec %s | FileCheck %s

define i32 @test() {
; CHECK-LABEL: @test(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr null, align 1
; CHECK-NEXT: [[WIDE_LOAD136:%.*]] = load <16 x i8>, ptr null, align 1
; CHECK-NEXT: [[WIDE_LOAD137:%.*]] = load <16 x i8>, ptr null, align 1
; CHECK-NEXT: [[WIDE_LOAD138:%.*]] = load <16 x i8>, ptr null, align 1
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i8> [[WIDE_LOAD]], zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[WIDE_LOAD136]], zeroinitializer
; CHECK-NEXT: [[TMP4:%.*]] = or <16 x i8> [[WIDE_LOAD137]], zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = or <16 x i8> [[WIDE_LOAD138]], zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = icmp ult <16 x i8> [[TMP2]], zeroinitializer
; CHECK-NEXT: [[TMP7:%.*]] = icmp ult <16 x i8> [[TMP3]], zeroinitializer
; CHECK-NEXT: [[TMP8:%.*]] = icmp ult <16 x i8> [[TMP4]], zeroinitializer
; CHECK-NEXT: [[TMP9:%.*]] = icmp ult <16 x i8> [[TMP5]], zeroinitializer
; CHECK-NEXT: [[TMP10:%.*]] = or <16 x i8> [[TMP0]], zeroinitializer
; CHECK-NEXT: [[TMP11:%.*]] = or <16 x i8> [[TMP1]], zeroinitializer
; CHECK-NEXT: [[TMP12:%.*]] = icmp ult <16 x i8> [[TMP10]], zeroinitializer
; CHECK-NEXT: [[TMP13:%.*]] = icmp ult <16 x i8> [[TMP11]], zeroinitializer
; CHECK-NEXT: ret i32 0
;
entry:
%wide.load = load <16 x i8>, ptr null, align 1
%wide.load136 = load <16 x i8>, ptr null, align 1
%wide.load137 = load <16 x i8>, ptr null, align 1
%wide.load138 = load <16 x i8>, ptr null, align 1
%0 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
%1 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
%2 = or <16 x i8> %wide.load, zeroinitializer
%3 = or <16 x i8> %wide.load136, zeroinitializer
%4 = or <16 x i8> %wide.load137, zeroinitializer
%5 = or <16 x i8> %wide.load138, zeroinitializer
%6 = icmp ult <16 x i8> %2, zeroinitializer
%7 = icmp ult <16 x i8> %3, zeroinitializer
%8 = icmp ult <16 x i8> %4, zeroinitializer
%9 = icmp ult <16 x i8> %5, zeroinitializer
%10 = or <16 x i8> %0, zeroinitializer
%11 = or <16 x i8> %1, zeroinitializer
%12 = icmp ult <16 x i8> %10, zeroinitializer
%13 = icmp ult <16 x i8> %11, zeroinitializer
ret i32 0
}
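
If the IR above changes, the CHECK lines can be regenerated with the script referenced in the NOTE line; a typical invocation looks like the following (the build/bin/opt path is an assumption about the local build layout):

```sh
llvm/utils/update_test_checks.py --opt-binary=build/bin/opt \
  llvm/test/Transforms/SLPVectorizer/X86/revec-getAltInstrMask.ll
```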