Skip to content

[VectorCombine] Handle shuffle of selects #128032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 6, 2025
Merged

Conversation

ParkHanbum
Copy link
Contributor

(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m)
-> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))

The behaviour of SelectInst on vectors is the same as for
V'select[i] = Condition[i] ? V'True[i] : V'False[i].

If a ShuffleVector is performed on two selects, it will be like:
V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])

That's why a ShuffleVector with two SelectInst is equivalent to
first ShuffleVector Condition/True/False and then SelectInst that
result.

This patch implements the transforming described above.

Proof: https://alive2.llvm.org/ce/z/97wfHp
Fixed: #120775

@llvmbot
Copy link
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: hanbeom (ParkHanbum)

Changes

(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m)
-> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))

The behaviour of SelectInst on vectors is the same as for
V'select[i] = Condition[i] ? V'True[i] : V'False[i].

If a ShuffleVector is performed on two selects, it will be like:
V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])

That's why a ShuffleVector with two SelectInst is equivalent to
first ShuffleVector Condition/True/False and then SelectInst that
result.

This patch implements the transforming described above.

Proof: https://alive2.llvm.org/ce/z/97wfHp
Fixed: #120775


Full diff: https://github.com/llvm/llvm-project/pull/128032.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+51)
  • (modified) llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll (+16-16)
  • (modified) llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll (+11-13)
  • (modified) llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll (+18)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 746742e14d080..9ea62a5d24393 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -119,6 +119,7 @@ class VectorCombine {
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
+  bool foldShuffleOfSelects(Instruction &I);
   bool foldShuffleOfCastops(Instruction &I);
   bool foldShuffleOfShuffles(Instruction &I);
   bool foldShuffleOfIntrinsics(Instruction &I);
@@ -1899,6 +1900,55 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
   return true;
 }
 
+/// Try to convert,
+/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
+/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
+bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
+  ArrayRef<int> Mask;
+  Value *C1, *T1, *F1, *C2, *T2, *F2;
+  if (!match(&I, m_Shuffle(
+                     m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
+                     m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
+                     m_Mask(Mask))))
+    return false;
+
+  auto SelectOp = Instruction::Select;
+
+  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
+  auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
+  auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
+  auto *T1VecTy = dyn_cast<FixedVectorType>(T1->getType());
+  auto *F1VecTy = dyn_cast<FixedVectorType>(F1->getType());
+
+  if (!C1VecTy || !C2VecTy)
+    return false;
+
+  InstructionCost OldCost = TTI.getCmpSelInstrCost(
+      SelectOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  OldCost += TTI.getCmpSelInstrCost(SelectOp, T2->getType(), C2VecTy,
+                                    CmpInst::BAD_ICMP_PREDICATE, CostKind);
+  OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
+                                Mask, CostKind);
+
+  InstructionCost NewCost =
+      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, C1VecTy, Mask,
+                         CostKind, 0, nullptr, {C1, C2}, &I);
+  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, T1VecTy,
+                                Mask, CostKind, 0, nullptr, {T1, T2}, &I);
+  NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, F1VecTy,
+                                Mask, CostKind, 0, nullptr, {F1, F2}, &I);
+  NewCost += TTI.getCmpSelInstrCost(SelectOp, T1->getType(), DstVecTy,
+                                    CmpInst::BAD_ICMP_PREDICATE, CostKind);
+
+  Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
+  Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
+  Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
+  Value *NewRes = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
+
+  replaceValue(I, *NewRes);
+  return true;
+}
+
 /// Try to convert "shuffle (castop), (castop)" with a shared castop operand
 /// into "castop (shuffle)".
 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
@@ -3352,6 +3402,7 @@ bool VectorCombine::run() {
       case Instruction::ShuffleVector:
         MadeChange |= foldPermuteOfBinops(I);
         MadeChange |= foldShuffleOfBinops(I);
+        MadeChange |= foldShuffleOfSelects(I);
         MadeChange |= foldShuffleOfCastops(I);
         MadeChange |= foldShuffleOfShuffles(I);
         MadeChange |= foldShuffleOfIntrinsics(I);
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
index c2ed7b9c84523..7c9baf7786733 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
@@ -87,11 +87,11 @@ define <4 x i64> @x86_pblendvb_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b, <4 x i64>
 ; CHECK-LABEL: @x86_pblendvb_v8i32_v4i32(
 ; CHECK-NEXT:    [[C_BC:%.*]] = bitcast <4 x i64> [[C:%.*]] to <8 x i32>
 ; CHECK-NEXT:    [[D_BC:%.*]] = bitcast <4 x i64> [[D:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <8 x i32> [[C_BC]], [[D_BC]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <8 x i32> [[C_BC]], [[D_BC]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[TMP2]], <8 x i32> [[TMP1]]
-; CHECK-NEXT:    [[RES:%.*]] = bitcast <8 x i32> [[TMP3]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP2]], <8 x i32> [[TMP3]]
+; CHECK-NEXT:    [[RES:%.*]] = bitcast <8 x i32> [[TMP4]] to <4 x i64>
 ; CHECK-NEXT:    ret <4 x i64> [[RES]]
 ;
   %a.bc = bitcast <4 x i64> %a to <32 x i8>
@@ -118,11 +118,11 @@ define <4 x i64> @x86_pblendvb_v16i16_v8i16(<4 x i64> %a, <4 x i64> %b, <4 x i64
 ; CHECK-LABEL: @x86_pblendvb_v16i16_v8i16(
 ; CHECK-NEXT:    [[C_BC:%.*]] = bitcast <4 x i64> [[C:%.*]] to <16 x i16>
 ; CHECK-NEXT:    [[D_BC:%.*]] = bitcast <4 x i64> [[D:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <16 x i16> [[C_BC]], [[D_BC]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i64> [[A:%.*]] to <16 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <16 x i16> [[C_BC]], [[D_BC]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i64> [[B:%.*]] to <16 x i16>
-; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[CMP]], <16 x i16> [[TMP2]], <16 x i16> [[TMP1]]
-; CHECK-NEXT:    [[RES:%.*]] = bitcast <16 x i16> [[TMP3]] to <4 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i64> [[A:%.*]] to <16 x i16>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> [[TMP1]], <16 x i16> [[TMP2]], <16 x i16> [[TMP3]]
+; CHECK-NEXT:    [[RES:%.*]] = bitcast <16 x i16> [[TMP4]] to <4 x i64>
 ; CHECK-NEXT:    ret <4 x i64> [[RES]]
 ;
   %a.bc = bitcast <4 x i64> %a to <32 x i8>
@@ -255,11 +255,11 @@ define <8 x i64> @x86_pblendvb_v16i32_v8i32(<8 x i64> %a, <8 x i64> %b, <8 x i64
 ; CHECK-LABEL: @x86_pblendvb_v16i32_v8i32(
 ; CHECK-NEXT:    [[C_BC:%.*]] = bitcast <8 x i64> [[C:%.*]] to <16 x i32>
 ; CHECK-NEXT:    [[D_BC:%.*]] = bitcast <8 x i64> [[D:%.*]] to <16 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <16 x i32> [[C_BC]], [[D_BC]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i64> [[A:%.*]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <16 x i32> [[C_BC]], [[D_BC]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i64> [[B:%.*]] to <16 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[CMP]], <16 x i32> [[TMP2]], <16 x i32> [[TMP1]]
-; CHECK-NEXT:    [[RES:%.*]] = bitcast <16 x i32> [[TMP3]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <8 x i64> [[A:%.*]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> [[TMP1]], <16 x i32> [[TMP2]], <16 x i32> [[TMP3]]
+; CHECK-NEXT:    [[RES:%.*]] = bitcast <16 x i32> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    ret <8 x i64> [[RES]]
 ;
   %a.bc = bitcast <8 x i64> %a to <64 x i8>
@@ -286,11 +286,11 @@ define <8 x i64> @x86_pblendvb_v32i16_v16i16(<8 x i64> %a, <8 x i64> %b, <8 x i6
 ; CHECK-LABEL: @x86_pblendvb_v32i16_v16i16(
 ; CHECK-NEXT:    [[C_BC:%.*]] = bitcast <8 x i64> [[C:%.*]] to <32 x i16>
 ; CHECK-NEXT:    [[D_BC:%.*]] = bitcast <8 x i64> [[D:%.*]] to <32 x i16>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <32 x i16> [[C_BC]], [[D_BC]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i64> [[A:%.*]] to <32 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <32 x i16> [[C_BC]], [[D_BC]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i64> [[B:%.*]] to <32 x i16>
-; CHECK-NEXT:    [[TMP3:%.*]] = select <32 x i1> [[CMP]], <32 x i16> [[TMP2]], <32 x i16> [[TMP1]]
-; CHECK-NEXT:    [[RES:%.*]] = bitcast <32 x i16> [[TMP3]] to <8 x i64>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <8 x i64> [[A:%.*]] to <32 x i16>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <32 x i1> [[TMP1]], <32 x i16> [[TMP2]], <32 x i16> [[TMP3]]
+; CHECK-NEXT:    [[RES:%.*]] = bitcast <32 x i16> [[TMP4]] to <8 x i64>
 ; CHECK-NEXT:    ret <8 x i64> [[RES]]
 ;
   %a.bc = bitcast <8 x i64> %a to <64 x i8>
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
index 09875c5e0af40..7fc348a1ad9c4 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shuffletoidentity.ll
@@ -451,19 +451,14 @@ define <8 x i8> @icmpsel(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 
 define <8 x i8> @icmpsel_diffentcond(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
 ; CHECK-LABEL: @icmpsel_diffentcond(
-; CHECK-NEXT:    [[AB:%.*]] = shufflevector <8 x i8> [[A:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[AT:%.*]] = shufflevector <8 x i8> [[A]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[BB:%.*]] = shufflevector <8 x i8> [[B:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
-; CHECK-NEXT:    [[BT:%.*]] = shufflevector <8 x i8> [[B]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
 ; CHECK-NEXT:    [[CB:%.*]] = shufflevector <8 x i8> [[C:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[CT:%.*]] = shufflevector <8 x i8> [[C]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
 ; CHECK-NEXT:    [[DB:%.*]] = shufflevector <8 x i8> [[D:%.*]], <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[DT:%.*]] = shufflevector <8 x i8> [[D]], <8 x i8> poison, <4 x i32> <i32 7, i32 6, i32 5, i32 4>
-; CHECK-NEXT:    [[ABT1:%.*]] = icmp slt <4 x i8> [[AT]], [[BT]]
-; CHECK-NEXT:    [[ABB1:%.*]] = icmp ult <4 x i8> [[AB]], [[BB]]
-; CHECK-NEXT:    [[ABT:%.*]] = select <4 x i1> [[ABT1]], <4 x i8> [[CT]], <4 x i8> [[DT]]
-; CHECK-NEXT:    [[ABB:%.*]] = select <4 x i1> [[ABB1]], <4 x i8> [[CB]], <4 x i8> [[DB]]
-; CHECK-NEXT:    [[R:%.*]] = shufflevector <4 x i8> [[ABT]], <4 x i8> [[ABB]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[ABT1:%.*]] = icmp slt <4 x i8> [[CT]], [[DT]]
+; CHECK-NEXT:    [[ABB1:%.*]] = icmp ult <4 x i8> [[CB]], [[DB]]
+; CHECK-NEXT:    [[SHUFFLE_CMP:%.*]] = shufflevector <4 x i1> [[ABT1]], <4 x i1> [[ABB1]], <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
+; CHECK-NEXT:    [[R:%.*]] = select <8 x i1> [[SHUFFLE_CMP]], <8 x i8> [[C1:%.*]], <8 x i8> [[D1:%.*]]
 ; CHECK-NEXT:    ret <8 x i8> [[R]]
 ;
   %ab = shufflevector <8 x i8> %a, <8 x i8> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
@@ -992,14 +987,17 @@ define void @maximal_legal_fpmath(ptr %addr1, ptr %addr2, ptr %result, float %va
 }
 
 ; Peek through (repeated) bitcasts to find a common source value.
+; TODO : We can remove the Shufflevector for A, B.
 define <4 x i64> @bitcast_smax_v8i32_v4i32(<4 x i64> %a, <4 x i64> %b) {
 ; CHECK-LABEL: @bitcast_smax_v8i32_v4i32(
 ; CHECK-NEXT:    [[A_BC0:%.*]] = bitcast <4 x i64> [[A:%.*]] to <8 x i32>
 ; CHECK-NEXT:    [[B_BC0:%.*]] = bitcast <4 x i64> [[B:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
-; CHECK-NEXT:    [[A_BC1:%.*]] = bitcast <4 x i64> [[A]] to <8 x i32>
-; CHECK-NEXT:    [[B_BC1:%.*]] = bitcast <4 x i64> [[B]] to <8 x i32>
-; CHECK-NEXT:    [[CONCAT:%.*]] = select <8 x i1> [[CMP]], <8 x i32> [[B_BC1]], <8 x i32> [[A_BC1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <8 x i32> [[A_BC0]], [[B_BC0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i64> [[B]], <4 x i64> [[B]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <4 x i64> [[TMP2]] to <8 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i64> [[A]], <4 x i64> [[A]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <4 x i64> [[TMP4]] to <8 x i32>
+; CHECK-NEXT:    [[CONCAT:%.*]] = select <8 x i1> [[TMP1]], <8 x i32> [[TMP3]], <8 x i32> [[TMP5]]
 ; CHECK-NEXT:    [[RES:%.*]] = bitcast <8 x i32> [[CONCAT]] to <4 x i64>
 ; CHECK-NEXT:    ret <4 x i64> [[RES]]
 ;
diff --git a/llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll b/llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll
index 685d661ea6bcd..70e25cab5845b 100644
--- a/llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/select-shuffle.ll
@@ -36,3 +36,21 @@ end:
   %t5 = shufflevector <4 x double> %t3, <4 x double> %t4, <4 x i32> <i32 0, i32 1, i32 6, i32 7>
   ret <4 x double> %t5
 }
+
+define <8 x i16> @shuffle_select_select(<4 x i16> %x, <4 x i16> %y, <4 x i16> %z) {
+; CHECK-LABEL: @shuffle_select_select(
+; CHECK-NEXT:    [[CMP_XY:%.*]] = icmp slt <4 x i16> [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[CMP_YZ:%.*]] = icmp slt <4 x i16> [[Y]], [[Z:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i1> [[CMP_XY]], <4 x i1> [[CMP_YZ]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i16> [[X]], <4 x i16> [[Y]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x i16> [[Z]], <4 x i16> [[X]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[RES:%.*]] = select <8 x i1> [[TMP1]], <8 x i16> [[TMP2]], <8 x i16> [[TMP3]]
+; CHECK-NEXT:    ret <8 x i16> [[RES]]
+;
+  %cmp.xy = icmp slt <4 x i16> %x, %y
+  %cmp.yz = icmp slt <4 x i16> %y, %z
+  %select.xz = select <4 x i1> %cmp.xy, <4 x i16> %x, <4 x i16> %z
+  %select.yx = select <4 x i1> %cmp.yz, <4 x i16> %y, <4 x i16> %x
+  %res = shufflevector <4 x i16> %select.xz, <4 x i16> %select.yx, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  ret <8 x i16> %res
+}

@ParkHanbum
Copy link
Contributor Author

@RKSimon would you please review this?

TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, C1VecTy, Mask,
CostKind, 0, nullptr, {C1, C2}, &I);
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, T1VecTy,
Mask, CostKind, 0, nullptr, {T1, T2}, &I);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can't use &I here - this is a new instruction - it should be on OldCost shuffle (and you can add similar instructions to to the OldCost getCmpSelInstrCost calls as well

Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
Value *NewRes = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);

replaceValue(I, *NewRes);
Copy link
Collaborator

@RKSimon RKSimon Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure you call Worklist.pushValue for each new shuffle - and BEFORE the replaceValue

Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
Value *NewRes = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

select can hold FMF flags - so we need to do an intersect of the IR flags from the original selects here (plus some tests)

SelectOp, T1->getType(), C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getCmpSelInstrCost(SelectOp, T2->getType(), C2VecTy,
CmpInst::BAD_ICMP_PREDICATE, CostKind);
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, DstVecTy,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this use getInstructionCost(I)? It can be more accurate to know the instruction.

Copy link

github-actions bot commented Feb 22, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ParkHanbum ParkHanbum force-pushed the vector_combine7 branch 2 times, most recently from e6b56c9 to abcad90 Compare February 22, 2025 20:47
(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m)
-> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))

The behaviour of SelectInst on vectors is the same as for
`V'select[i] = Condition[i] ? V'True[i] : V'False[i]`.

If a ShuffleVector is performed on two selects, it will be like:
`V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])`

That's why a ShuffleVector with two SelectInst is equivalent to
first ShuffleVector Condition/True/False and then SelectInst that
result.

This patch implements the transforming described above.

Proof: https://alive2.llvm.org/ce/z/97wfHp
Fixed: llvm#120775
@ParkHanbum
Copy link
Contributor Author

@davemgreen The updated code did not affect the test of aarch64. Sorry for the trouble caused by my mistake

if (isa<FPMathOperator>(NewSel))
cast<Instruction>(NewSel)->setFastMathFlags(Select0->getFastMathFlags());

Worklist.pushValue(NewSel);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

push the new shuffles as well

@ParkHanbum ParkHanbum requested a review from RKSimon February 25, 2025 18:51
@ParkHanbum ParkHanbum requested a review from RKSimon February 26, 2025 18:55
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
if (!C1VecTy || !C2VecTy)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can C1VecTy != C2VecTy?

auto *Select0 = cast<Instruction>(I.getOperand(0));
if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
if (auto *SI1FOp = dyn_cast<FPMathOperator>((I.getOperand(1))))
if (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that both/neither of the selects are FPMathOperator ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, since we're trying to combine two Select statements, I thought they should have the same FMF. do I need to think more about FMFs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably just replace it with:

if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
  if (SI0FOp->getFastMathFlags() != cast<FPMathOperator>((I.getOperand(1)))->getFastMathFlags())
    return false;

Worklist.pushValue(ShuffleCmp);
Worklist.pushValue(ShuffleTrue);
Worklist.pushValue(ShuffleFalse);
Worklist.pushValue(NewSel);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to pushValue(NewSel) - replaceValue should handle this for us

@ParkHanbum ParkHanbum requested a review from RKSimon February 27, 2025 22:28
Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a couple of final minors

auto *Select0 = cast<Instruction>(I.getOperand(0));
if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
if (auto *SI1FOp = dyn_cast<FPMathOperator>((I.getOperand(1))))
if (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably just replace it with:

if (auto *SI0FOp = dyn_cast<FPMathOperator>(Select0))
  if (SI0FOp->getFastMathFlags() != cast<FPMathOperator>((I.getOperand(1)))->getFastMathFlags())
    return false;

Value *NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);

// We presuppose that the SelectInsts have the same FMF.
if (isa<FPMathOperator>(NewSel))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid similar casts:

if (auto *SIFOp = dyn_cast<FPMathOperator>(NewSel))
  SIFOp->setFastMathFlags(Select0->getFastMathFlags());

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic was defective, so I corrected it... sorry abou it.

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@RKSimon RKSimon merged commit 5d1029b into llvm:main Mar 6, 2025
11 checks passed
@dyung
Copy link
Collaborator

dyung commented Mar 7, 2025

Hi @ParkHanbum, we started to see an assertion failure in the compiler which I bisected back to this change. I have put the details in #130250, can you take a look?

RKSimon pushed a commit that referenced this pull request Mar 7, 2025
…ects (#130281)

In the previous code (#128032), it specified the destination vector as the
getShuffleCost argument. Because the shuffle mask specifies the indices
of the two vectors specified as elements, the maximum value is twice the
size of the source vector. This causes a problem if the destination
vector is smaller than the source vector and specify an index in the
mask that exceeds the size of the destination vector.

Fix the problem by correcting the previous code, which was using wrong
argument in the Cost calculation.

Fixes #130250
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
(shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m)
-> (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))

The behaviour of SelectInst on vectors is the same as for
`V'select[i] = Condition[i] ? V'True[i] : V'False[i]`.

If a ShuffleVector is performed on two selects, it will be like:
`V'[mask] = (V'select[i] = Condition[i] ? V'True[i] : V'False[i])`

That's why a ShuffleVector with two SelectInst is equivalent to
first ShuffleVector Condition/True/False and then SelectInst that
result.

This patch implements the transforming described above.

Proof: https://alive2.llvm.org/ce/z/97wfHp
Fixes llvm#120775
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
…ects (llvm#130281)

In the previous code (llvm#128032), it specified the destination vector as the
getShuffleCost argument. Because the shuffle mask specifies the indices
of the two vectors specified as elements, the maximum value is twice the
size of the source vector. This causes a problem if the destination
vector is smaller than the source vector and specify an index in the
mask that exceeds the size of the destination vector.

Fix the problem by correcting the previous code, which was using wrong
argument in the Cost calculation.

Fixes llvm#130250
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[VectorCombine] Handle shuffle of selects
5 participants