[X86] Generate kmov for masking integers #120593


Merged: 24 commits merged into llvm:main on Mar 3, 2025

Conversation

abhishek-kaushik22 (Contributor)


When we have an integer used as a bit mask, the LLVM IR looks something like this:
```
%1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
  %cmp1 = icmp ne <16 x i32> %1, zeroinitializer
```
where `.splat` is a vector containing the mask in all lanes.
The assembly generated for this looks like:
```
vpbroadcastd    %ecx, %zmm0
vptestmd        .LCPI0_0(%rip), %zmm0, %k1
```
where we have a constant table of powers of 2.
Instead of doing this, we could just move the relevant bits directly into a `k` register with a single `kmov` instruction (`kmovw %ecx, %k1`). This is faster and also reduces code size.
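For context, the kind of scalar source that vectorizes into the pattern above looks roughly like the sketch below; the function name and loop shape are illustrative and not taken from this PR.
```cpp
// Illustrative sketch: each bit of `mask` gates one lane, which the
// vectorizer turns into the splat + and-with-powers-of-2 + compare pattern.
void add_masked(float *c, const float *a, const float *b, int mask) {
  for (int i = 0; i < 16; ++i)
    if (mask & (1 << i)) // lane i is active when bit i of the mask is set
      c[i] = a[i] + b[i];
}
```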
@llvmbot (Member) commented Dec 19, 2024

@llvm/pr-subscribers-backend-x86

Author: None (abhishek-kaushik22)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/120593.diff

3 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelDAGToDAG.cpp (+66-13)
  • (added) llvm/test/CodeGen/X86/kmov.ll (+205)
  • (modified) llvm/test/CodeGen/X86/pr78897.ll (+2-2)
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index bb20e6ecf281b0..8c199a30dfbce7 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -592,7 +592,7 @@ namespace {
     bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB,
                         SDNode *ParentC, SDValue A, SDValue B, SDValue C,
                         uint8_t Imm);
-    bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask);
+    bool tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc, SDValue Mask);
     bool tryMatchBitSelect(SDNode *N);
 
     MachineSDNode *emitPCMPISTR(unsigned ROpc, unsigned MOpc, bool MayFoldLoad,
@@ -4897,10 +4897,10 @@ VPTESTM_CASE(v32i16, WZ##SUFFIX)
 #undef VPTESTM_CASE
 }
 
-// Try to create VPTESTM instruction. If InMask is not null, it will be used
-// to form a masked operation.
-bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
-                                 SDValue InMask) {
+// Try to create VPTESTM or KMOV instruction. If InMask is not null, it will be
+// used to form a masked operation.
+bool X86DAGToDAGISel::tryVPTESTMOrKMOV(SDNode *Root, SDValue Setcc,
+                                       SDValue InMask) {
   assert(Subtarget->hasAVX512() && "Expected AVX512!");
   assert(Setcc.getSimpleValueType().getVectorElementType() == MVT::i1 &&
          "Unexpected VT!");
@@ -4975,12 +4975,70 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
   };
 
+  auto canUseKMOV = [&]() {
+    if (Src0.getOpcode() != X86ISD::VBROADCAST)
+      return false;
+
+    if (Src1.getOpcode() != ISD::LOAD ||
+        Src1.getOperand(1).getOpcode() != X86ISD::Wrapper ||
+        Src1.getOperand(1).getOperand(0).getOpcode() != ISD::TargetConstantPool)
+      return false;
+
+    const auto *ConstPool =
+        dyn_cast<ConstantPoolSDNode>(Src1.getOperand(1).getOperand(0));
+    if (!ConstPool)
+      return false;
+
+    const auto *ConstVec = ConstPool->getConstVal();
+    const auto *ConstVecType = dyn_cast<FixedVectorType>(ConstVec->getType());
+    if (!ConstVecType)
+      return false;
+
+    for (unsigned i = 0, e = ConstVecType->getNumElements(), k = 1; i != e;
+         ++i, k *= 2) {
+      const auto *Element = ConstVec->getAggregateElement(i);
+      if (llvm::isa<llvm::UndefValue>(Element)) {
+        for (unsigned j = i + 1; j != e; ++j) {
+          if (!llvm::isa<llvm::UndefValue>(ConstVec->getAggregateElement(j)))
+            return false;
+        }
+        return i != 0;
+      }
+
+      if (Element->getUniqueInteger() != k) {
+        return false;
+      }
+    }
+
+    return true;
+  };
+
   // We can only fold loads if the sources are unique.
   bool CanFoldLoads = Src0 != Src1;
 
   bool FoldedLoad = false;
   SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  SDLoc dl(Root);
+  bool IsTestN = CC == ISD::SETEQ;
+  MachineSDNode *CNode;
+  MVT ResVT = Setcc.getSimpleValueType();
   if (CanFoldLoads) {
+    if (canUseKMOV()) {
+      auto Op = Src0.getOperand(0);
+      if (Op.getSimpleValueType() == MVT::i8) {
+        Op = SDValue(CurDAG->getNode(ISD::ZERO_EXTEND, dl, MVT::i32, Op));
+      }
+      CNode = CurDAG->getMachineNode(
+          ResVT.getVectorNumElements() <= 8 ? X86::KMOVBkr : X86::KMOVWkr, dl,
+          ResVT, Op);
+      if (IsTestN)
+        CNode = CurDAG->getMachineNode(
+            ResVT.getVectorNumElements() <= 8 ? X86::KNOTBkk : X86::KNOTWkk, dl,
+            ResVT, SDValue(CNode, 0));
+      ReplaceUses(SDValue(Root, 0), SDValue(CNode, 0));
+      CurDAG->RemoveDeadNode(Root);
+      return true;
+    }
     FoldedLoad = tryFoldLoadOrBCast(Root, N0.getNode(), Src1, Tmp0, Tmp1, Tmp2,
                                     Tmp3, Tmp4);
     if (!FoldedLoad) {
@@ -4996,9 +5054,6 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
 
   bool IsMasked = InMask.getNode() != nullptr;
 
-  SDLoc dl(Root);
-
-  MVT ResVT = Setcc.getSimpleValueType();
   MVT MaskVT = ResVT;
   if (Widen) {
     // Widen the inputs using insert_subreg or copy_to_regclass.
@@ -5023,11 +5078,9 @@ bool X86DAGToDAGISel::tryVPTESTM(SDNode *Root, SDValue Setcc,
     }
   }
 
-  bool IsTestN = CC == ISD::SETEQ;
   unsigned Opc = getVPTESTMOpc(CmpVT, IsTestN, FoldedLoad, FoldedBCast,
                                IsMasked);
 
-  MachineSDNode *CNode;
   if (FoldedLoad) {
     SDVTList VTs = CurDAG->getVTList(MaskVT, MVT::Other);
 
@@ -5466,10 +5519,10 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
       SDValue N0 = Node->getOperand(0);
       SDValue N1 = Node->getOperand(1);
       if (N0.getOpcode() == ISD::SETCC && N0.hasOneUse() &&
-          tryVPTESTM(Node, N0, N1))
+          tryVPTESTMOrKMOV(Node, N0, N1))
         return;
       if (N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
-          tryVPTESTM(Node, N1, N0))
+          tryVPTESTMOrKMOV(Node, N1, N0))
         return;
     }
 
@@ -6393,7 +6446,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
   }
 
   case ISD::SETCC: {
-    if (NVT.isVector() && tryVPTESTM(Node, SDValue(Node, 0), SDValue()))
+    if (NVT.isVector() && tryVPTESTMOrKMOV(Node, SDValue(Node, 0), SDValue()))
       return;
 
     break;
diff --git a/llvm/test/CodeGen/X86/kmov.ll b/llvm/test/CodeGen/X86/kmov.ll
new file mode 100644
index 00000000000000..6d72a8923c5ab3
--- /dev/null
+++ b/llvm/test/CodeGen/X86/kmov.ll
@@ -0,0 +1,205 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s
+
+define dso_local void @foo_16_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp ne <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+; Function Attrs: mustprogress nounwind uwtable
+define dso_local void @foo_16_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_16_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovw %ecx, %k0
+; CHECK-NEXT:    knotw %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %zmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %zmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vmovups %zmm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <16 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <16 x i32> %.splatinsert, <16 x i32> poison, <16 x i32> zeroinitializer
+  %1 = and <16 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128, i32 256, i32 512, i32 1024, i32 2048, i32 4096, i32 8192, i32 16384, i32 32768>
+  %hir.cmp.45 = icmp eq <16 x i32> %1, zeroinitializer
+  %2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
+  tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp ne <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_8_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_8_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %ymm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %ymm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vmovups %ymm0, (%rdi) {%k1}
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <8 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <8 x i32> %.splatinsert, <8 x i32> poison, <8 x i32> zeroinitializer
+  %1 = and <8 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8, i32 16, i32 32, i32 64, i32 128>
+  %hir.cmp.45 = icmp eq <8 x i32> %1, zeroinitializer
+  %2 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %b, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %3 = tail call <8 x float> @llvm.masked.load.v8f32.p0(ptr %a, i32 4, <8 x i1> %hir.cmp.45, <8 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <8 x float> %2, %3
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %4, ptr %c, i32 4, <8 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp ne <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_4_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_4_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %0 = and i32 %mask, 65535
+  %.splatinsert = insertelement <4 x i32> poison, i32 %0, i64 0
+  %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer
+  %1 = and <4 x i32> %.splat, <i32 1, i32 2, i32 4, i32 8>
+  %hir.cmp.45 = icmp eq <4 x i32> %1, zeroinitializer
+  %2 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %b, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %3 = tail call <4 x float> @llvm.masked.load.v4f32.p0(ptr %a, i32 4, <4 x i1> %hir.cmp.45, <4 x float> poison)
+  %4 = fadd reassoc nsz arcp contract afn <4 x float> %2, %3
+  tail call void @llvm.masked.store.v4f32.p0(<4 x float> %4, ptr %c, i32 4, <4 x i1> %hir.cmp.45)
+  ret void
+}
+
+define dso_local void @foo_2_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_ne:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp ne <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+define dso_local void @foo_2_eq(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
+; CHECK-LABEL: foo_2_eq:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    kmovb %ecx, %k0
+; CHECK-NEXT:    knotb %k0, %k0
+; CHECK-NEXT:    kshiftlb $6, %k0, %k0
+; CHECK-NEXT:    kshiftrb $6, %k0, %k1
+; CHECK-NEXT:    vmovups (%rdx), %xmm0 {%k1} {z}
+; CHECK-NEXT:    vmovups (%rsi), %xmm1 {%k1} {z}
+; CHECK-NEXT:    vaddps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vmovups %xmm0, (%rdi) {%k1}
+; CHECK-NEXT:    retq
+entry:
+  %.splatinsert = insertelement <2 x i32> poison, i32 %mask, i64 0
+  %.splat = shufflevector <2 x i32> %.splatinsert, <2 x i32> poison, <2 x i32> zeroinitializer
+  %0 = and <2 x i32> %.splat, <i32 1, i32 2>
+  %hir.cmp.44 = icmp eq <2 x i32> %0, zeroinitializer
+  %1 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %b, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %2 = tail call <2 x float> @llvm.masked.load.v2f32.p0(ptr %a, i32 4, <2 x i1> %hir.cmp.44, <2 x float> poison)
+  %3 = fadd reassoc nsz arcp contract afn <2 x float> %1, %2
+  tail call void @llvm.masked.store.v2f32.p0(<2 x float> %3, ptr %c, i32 4, <2 x i1> %hir.cmp.44)
+  ret void
+}
+
+declare <2 x float> @llvm.masked.load.v2f32.p0(ptr nocapture, i32 immarg, <2 x i1>, <2 x float>) #1
+
+declare void @llvm.masked.store.v2f32.p0(<2 x float>, ptr nocapture, i32 immarg, <2 x i1>) #2
+
+declare <4 x float> @llvm.masked.load.v4f32.p0(ptr nocapture, i32 immarg, <4 x i1>, <4 x float>) #1
+
+declare void @llvm.masked.store.v4f32.p0(<4 x float>, ptr nocapture, i32 immarg, <4 x i1>) #2
+
+declare <8 x float> @llvm.masked.load.v8f32.p0(ptr nocapture, i32 immarg, <8 x i1>, <8 x float>)
+
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr nocapture, i32 immarg, <8 x i1>)
+
+declare <16 x float> @llvm.masked.load.v16f32.p0(ptr nocapture, i32 immarg, <16 x i1>, <16 x float>)
+
+declare void @llvm.masked.store.v16f32.p0(<16 x float>, ptr nocapture, i32 immarg, <16 x i1>)
diff --git a/llvm/test/CodeGen/X86/pr78897.ll b/llvm/test/CodeGen/X86/pr78897.ll
index 56e4ec2bc8ecbb..38a1800df956b5 100644
--- a/llvm/test/CodeGen/X86/pr78897.ll
+++ b/llvm/test/CodeGen/X86/pr78897.ll
@@ -256,8 +256,8 @@ define <16 x i8> @produceShuffleVectorForByte(i8 zeroext %0) nounwind {
 ;
 ; X64-AVX512-LABEL: produceShuffleVectorForByte:
 ; X64-AVX512:       # %bb.0: # %entry
-; X64-AVX512-NEXT:    vpbroadcastb %edi, %xmm0
-; X64-AVX512-NEXT:    vptestnmb {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %k1
+; X64-AVX512-NEXT:    kmovw %edi, %k0
+; X64-AVX512-NEXT:    knotw %k0, %k1
 ; X64-AVX512-NEXT:    vmovdqu8 {{.*#+}} xmm0 {%k1} {z} = [17,17,17,17,17,17,17,17,u,u,u,u,u,u,u,u]
 ; X64-AVX512-NEXT:    vmovq %xmm0, %rax
 ; X64-AVX512-NEXT:    movabsq $1229782938247303440, %rcx # imm = 0x1111111111111110

@abhishek-kaushik22 (Contributor, Author)

@e-kud @phoebewang @RKSimon @topperc can you please review?

abhishek-kaushik22 changed the title from "Generate kmov for masking integers" to "[X86] Generate kmov for masking integers" on Dec 20, 2024
@RKSimon (Collaborator) left a comment

Tests need some cleanup/simplification - ideally it's good to precommit the tests with current codegen and then show the codegen change when the combine is added in a later commit.

What was the reason behind performing this in DAGToDAG and not as a regular DAG combine?

; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mcpu=skylake-avx512 | FileCheck %s

define dso_local void @foo_16_ne(ptr nocapture noundef writeonly %c, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %b, i32 noundef %mask) {
Collaborator

replace foo with a descriptive test name

Contributor Author

I couldn't think of one, so I've replaced `foo` with the PR number.

%2 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %b, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
%3 = tail call <16 x float> @llvm.masked.load.v16f32.p0(ptr %a, i32 4, <16 x i1> %hir.cmp.45, <16 x float> poison)
%4 = fadd reassoc nsz arcp contract afn <16 x float> %2, %3
tail call void @llvm.masked.store.v16f32.p0(<16 x float> %4, ptr %c, i32 4, <16 x i1> %hir.cmp.45)
Collaborator

Are all the masked load/stores necessary? Aren't there simpler predicated instruction patterns that could be used?

Contributor Author

I've removed the masked loads/stores by directly returning the vector, but it didn't really work for vector length 4, where we had an X86ISD::PCMPEQ node which isn't part of this pattern.

@abhishek-kaushik22 (Contributor, Author)

> Tests need some cleanup/simplification - ideally it's good to precommit the tests with current codegen and then show the codegen change when the combine is added in a later commit.

Thanks, I'll take care of that next time.

> What was the reason behind performing this in DAGToDAG and not as a regular DAG combine?

I wasn't sure what this could be combined to; I didn't find any useful patterns that get selected as KMOV. I only see these patterns in TableGen:

/*544207*/    OPC_MorphNodeTo1None, TARGET_VAL(X86::KMOVWkr),
                  /*MVT::v16i1*/22, 1/*#Ops*/, 5,
              // Src: (insert_subvector:{ *:[v16i1] } immAllZerosV:{ *:[v16i1] }, (scalar_to_vector:{ *:[v1i1] } GR8:{ *:[i8] }:$src), 0:{ *:[iPTR] }) - Complexity = 15
              // Dst: (KMOVWkr:{ *:[v16i1] } (AND32ri:{ *:[i32] }:{ *:[i32] } (INSERT_SUBREG:{ *:[i32] } (IMPLICIT_DEF:{ *:[i32] }), GR8:{ *:[i8] }:$src, sub_8bit:{ *:[i32] }), 1:{ *:[i32] }))

If there's a simpler pattern that I missed, please let me know.

- Remove attributes
- Remove fast math flags
- Simplify tests by removing mask/loads
@RKSimon (Collaborator) commented Dec 23, 2024

This is what you're matching:

Type-legalized selection DAG: %bb.0 'pr120593_16_ne:'
SelectionDAG has 31 nodes:
  t0: ch,glue = EntryToken
  t2: i32,ch = CopyFromReg t0, Register:i32 %0
          t34: v16i32 = BUILD_VECTOR t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2, t2
          t23: v16i32 = BUILD_VECTOR Constant:i32<1>, Constant:i32<2>, Constant:i32<4>, Constant:i32<8>, Constant:i32<16>, Constant:i32<32>, Constant:i32<64>, Constant:i32<128>, Constant:i32<256>, Constant:i32<512>, Constant:i32<1024>, Constant:i32<2048>, Constant:i32<4096>, Constant:i32<8192>, Constant:i32<16384>, Constant:i32<32768>
        t24: v16i32 = and t34, t23
        t26: v16i32 = BUILD_VECTOR Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>
      t28: v16i1 = setcc t24, t26, setne:ch
    t29: v16i8 = any_extend t28
  t32: ch,glue = CopyToReg t0, Register:v16i8 $xmm0, t29
  t33: ch = X86ISD::RET_GLUE t32, TargetConstant:i32<0>, Register:v16i8 $xmm0, t32:1

You'd probably be better off trying to fold this inside combineSETCC - there's already a number of similar patterns that get folded in there.

What you're after is folding this to something like: (sext v16i8 (v16i1 bitcast(i16 trunc (i32 t2))))
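In SelectionDAG terms, a minimal sketch of that fold might look like the following fragment (intended for the X86ISelLowering.cpp context); the helper name, and the assumption that the splat and power-of-two checks have already been done by the caller, are illustrative rather than the exact code that landed.
```cpp
// Sketch: fold (setcc ne (and (splat X), <1,2,4,...>), 0) into
// (bitcast (trunc X to iN) to vNi1), which selection can then place in a
// k-register with a single KMOV.
static SDValue foldBitMaskSetCC(SDValue SplatVal, EVT CmpVT, ISD::CondCode CC,
                                const SDLoc &DL, SelectionDAG &DAG) {
  unsigned NumElts = CmpVT.getVectorNumElements();  // e.g. 16 for v16i32
  MVT BoolVT = MVT::getVectorVT(MVT::i1, NumElts);  // v16i1
  MVT IntVT = MVT::getIntegerVT(NumElts);           // i16
  // Keep only the low NumElts bits of the scalar mask and reinterpret them
  // as a vector of i1 lanes.
  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, IntVT, SplatVal);
  SDValue Mask = DAG.getBitcast(BoolVT, Trunc);
  // For "icmp eq ..., 0" a lane is active when its bit is clear, so invert.
  if (CC == ISD::SETEQ)
    Mask = DAG.getNOT(DL, Mask, BoolVT);
  return Mask;
}
```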

Combine to KMOV instead of doing it in ISEL
@abhishek-kaushik22 (Contributor, Author)

> You'd probably be better off trying to fold this inside combineSETCC - there's already a number of similar patterns that get folded in there.
>
> What you're after is folding this to something like: (sext v16i8 (v16i1 bitcast(i16 trunc (i32 t2))))

Thanks! I've changed to match this pattern as a DAG combine.

@abhishek-kaushik22 (Contributor, Author)

> @abhishek-kaushik22 Can you confirm if this will fix #72803 please?

Yes, for

define void @example(ptr noalias nocapture noundef writeonly sret([8 x i32]) align 4 dereferenceable(32) %_0, ptr noalias nocapture noundef readonly align 4 dereferenceable(32) %a, i64 noundef %m, ptr noalias nocapture noundef readonly align 4 dereferenceable(32) %b) {
start:
  %0 = insertelement <8 x i64> poison, i64 %m, i64 0
  %1 = shufflevector <8 x i64> %0, <8 x i64> poison, <8 x i32> zeroinitializer
  %2 = and <8 x i64> %1, <i64 1, i64 2, i64 4, i64 8, i64 16, i64 32, i64 64, i64 128>
  %3 = icmp eq <8 x i64> %2, zeroinitializer
  %4 = load <8 x i32>, ptr %b, align 4
  %5 = select <8 x i1> %3, <8 x i32> zeroinitializer, <8 x i32> %4
  %6 = load <8 x i32>, ptr %a, align 4
  %7 = add <8 x i32> %6, %5
  store <8 x i32> %7, ptr %_0, align 4
  ret void
}

this generates

        movq    %rdi, %rax
        kmovd   %edx, %k1
        vmovdqu (%rsi), %ymm0
        vpaddd  (%rcx), %ymm0, %ymm0 {%k1}
        vmovdqu %ymm0, (%rdi)
        vzeroupper
        retq

@@ -0,0 +1,372 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=x86_64-- -mcpu=x86-64-v4 | FileCheck %s
Collaborator

Add a -mcpu=knl run to check what happens on AVX512 targets with less ISA support

Contributor Author

> Add a -mcpu=knl run to check what happens on AVX512 targets with less ISA support

This one doesn't really work as expected for most of the cases because of the PCMPEQ node.

@RKSimon (Collaborator) left a comment

One minor, but I think it's almost there.

DAG.getConstant(Mask, DL, BroadcastOpVT));
}
// We can't extract more than 16 bits using this pattern, because 2^{17} will
// not fit in an i16 and a vXi32 where X > 16 is more than 512 bits.
Collaborator

Where is the check that VT is not greater than v16i1?

Contributor Author

I've added the check here: `UndefElts.getBitWidth() > 16`.

@RKSimon (Collaborator) left a comment

A few final minors

@RKSimon (Collaborator) left a comment

LGTM - cheers

abhishek-kaushik22 merged commit 17857d9 into llvm:main on Mar 3, 2025
11 checks passed
abhishek-kaushik22 deleted the kmov branch on March 3, 2025, 15:05
@abhishek-kaushik22 (Contributor, Author)

Thank You @RKSimon :)

jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025

Successfully merging this pull request may close these issues.

Unnecessary roundtrip through avx512 vector registers for integer mask