
[SelectionDAG] NFC: Add target hooks to enable vector coercion in CopyToReg / CopyFromReg #66134


Closed
jrbyrnes wants to merge 1 commit

Conversation

jrbyrnes
Contributor

@jrbyrnes jrbyrnes commented Sep 12, 2023

In cases where we are copying vectors across basic block boundaries, we will emit CopyToReg / CopyFromReg pairs. For non-legal vector types, we typically scalarize and legalize each scalar, then emit a CopyTo / CopyFrom for each scalar. However, in some cases, we may be able to pack the vector into fewer registers. As an example, AMDGPU can pack a v4i8 into a single register by treating it as an i32 (rather than four registers, each of type i16).

This NFC patch introduces the target hooks to implement such functionality.
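
For illustration, the kind of override a target could plug into the new hook might look roughly like the sketch below. This is not code from this patch: the target class name, the v7i8 -> v2i32 choice, and the exact Source/Dest conventions are assumptions made for the example.

```c++
// Hypothetical override of the proposed hook (not part of this patch).
// On the CopyToReg path it widens an odd-sized byte vector (e.g. v7i8) with
// undef lanes and reinterprets it as a vector of 32-bit elements (e.g. v2i32);
// the CopyFromReg path undoes the coercion.
SDValue MyTargetLowering::lowerVectorCopyReg(bool IsABIRegCopy,
                                             SelectionDAG &DAG, const SDLoc &DL,
                                             SDValue &Val, EVT Source, EVT Dest,
                                             bool IsCopyTo) const {
  EVT VecVT = IsCopyTo ? Source : Dest; // the illegal vector type, e.g. v7i8
  EVT RegVT = IsCopyTo ? Dest : Source; // the coerced register type, e.g. v2i32
  if (!VecVT.isVector() || VecVT.getScalarSizeInBits() != 8)
    return SDValue(); // let the generic scalarization path handle it

  // Pad the byte count up to a multiple of 4 so the value fills whole
  // 32-bit registers.
  unsigned PaddedBytes = alignTo(VecVT.getVectorNumElements(), 4);
  EVT PaddedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8, PaddedBytes);
  if (PaddedVT.getSizeInBits() != RegVT.getSizeInBits())
    return SDValue();

  if (IsCopyTo) {
    // v7i8 -> v8i8 (undef tail) -> v2i32
    SDValue Widened =
        DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PaddedVT,
                    DAG.getUNDEF(PaddedVT), Val,
                    DAG.getVectorIdxConstant(0, DL));
    return DAG.getNode(ISD::BITCAST, DL, RegVT, Widened);
  }

  // CopyFromReg: v2i32 -> v8i8 -> v7i8 (drop the padding lanes).
  SDValue Bytes = DAG.getNode(ISD::BITCAST, DL, PaddedVT, Val);
  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VecVT, Bytes,
                     DAG.getVectorIdxConstant(0, DL));
}
```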

@jrbyrnes jrbyrnes requested review from a team as code owners September 12, 2023 20:01
@llvmbot llvmbot added the backend:AMDGPU and llvm:SelectionDAG labels Sep 12, 2023
@llvmbot
Member

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-llvm-selectiondag

Changes

In cases where we are copying vectors of non-legal types across basic block boundaries, we will emit CopyToReg / CopyFromReg pairs. For such types, we typically scalarize and legalize each scalar, then emit a CopyTo / CopyFrom for each scalar. However, in some cases, we may be able to pack the vector into a single register, or at least into fewer registers. As an example, AMDGPU can pack a v4i8 into a single register by treating it as an i32 (rather than four registers, each of type i16).

This NFC patch introduces the target hooks to implement such functionality.

Patch is 37.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66134.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ByteProvider.h (+11-2)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+15-5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+9)
  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+54-32)
  • (modified) llvm/test/CodeGen/AMDGPU/idot4s.ll (+207-14)
  • (modified) llvm/test/CodeGen/AMDGPU/idot4u.ll (+59-17)
diff --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h
index 3187b4e68c56f3a..7fc2d9a876e6710 100644
--- a/llvm/include/llvm/CodeGen/ByteProvider.h
+++ b/llvm/include/llvm/CodeGen/ByteProvider.h
@@ -32,6 +32,11 @@ template <typename ISelOp> class ByteProvider {
   ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset)
       : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
 
+  ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset,
+               bool IsSigned)
+      : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset),
+        IsSigned(IsSigned) {}
+
   // TODO -- use constraint in c++20
   // Does this type correspond with an operation in selection DAG
   template <typename U> class is_op {
@@ -61,13 +66,17 @@ template <typename ISelOp> class ByteProvider {
   // DestOffset
   int64_t SrcOffset = 0;
 
+  // Tracks whether or not the byte is treated as a signed operand -- useful
+  // for arithmetic combines.
+  bool IsSigned = 0;
+
   ByteProvider() = default;
 
   static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
-                             int64_t VectorOffset) {
+                             int64_t VectorOffset, bool IsSigned = 0) {
     static_assert(is_op<ISelOp>().value,
                   "ByteProviders must contain an operation in selection DAG.");
-    return ByteProvider(Val, ByteOffset, VectorOffset);
+    return ByteProvider(Val, ByteOffset, VectorOffset, IsSigned);
   }
 
   static ByteProvider getConstantZero() {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 12b280d5b1a0bcd..4066fccf2312abe 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1076,10 +1076,10 @@ class TargetLoweringBase {
   /// This method returns the number of registers needed, and the VT for each
   /// register.  It also returns the VT and quantity of the intermediate values
   /// before they are promoted/expanded.
-  unsigned getVectorTypeBreakdown(LLVMContext &Context, EVT VT,
-                                  EVT &IntermediateVT,
-                                  unsigned &NumIntermediates,
-                                  MVT &RegisterVT) const;
+  virtual unsigned getVectorTypeBreakdown(LLVMContext &Context, EVT VT,
+                                          EVT &IntermediateVT,
+                                          unsigned &NumIntermediates,
+                                          MVT &RegisterVT) const;
 
   /// Certain targets such as MIPS require that some types such as vectors are
   /// always broken down into scalars in some contexts. This occurs even if the
@@ -1091,6 +1091,16 @@ class TargetLoweringBase {
                                   RegisterVT);
   }
 
+  /// Certain targets, such as AMDGPU, may coerce vectors of one type to another
+  /// to produce optimal code for CopyToReg / CopyFromReg pairs when dealing
+  /// with non-legal types -- e.g. v7i8 -> v2i32. This gives targets an
+  /// opportunity to do custom lowering in such cases.
+  virtual SDValue lowerVectorCopyReg(bool ISABIRegCopy, SelectionDAG &DAG,
+                                     const SDLoc &DL, SDValue &Val, EVT Source,
+                                     EVT Dest, bool IsCopyTo = true) const {
+    return SDValue();
+  };
+
   struct IntrinsicInfo {
     unsigned     opc = 0;          // target opcode
     EVT          memVT;            // memory VT
@@ -1598,7 +1608,7 @@ class TargetLoweringBase {
   }
 
   /// Return the type of registers that this ValueType will eventually require.
-  MVT getRegisterType(LLVMContext &Context, EVT VT) const {
+  virtual MVT getRegisterType(LLVMContext &Context, EVT VT) const {
     if (VT.isSimple())
       return getRegisterType(VT.getSimpleVT());
     if (VT.isVector()) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5a227ba398e1c11..afa6caae889dabb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -400,6 +400,11 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
     if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
       return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
 
+    if (auto TargetLowered = TLI.lowerVectorCopyReg(IsABIRegCopy, DAG, DL, Val,
+                                                    PartEVT, ValueVT, false)) {
+      // Give targets a chance to custom lower mismatched sizes
+      return TargetLowered;
+    }
     // If the parts vector has more elements than the value vector, then we
     // have a vector widening case (e.g. <2 x float> -> <4 x float>).
     // Extract the elements we want.
@@ -765,6 +770,10 @@ static void getCopyToPartsVector(SelectionDAG &DAG, const SDLoc &DL,
   } else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) {
     // Bitconvert vector->vector case.
     Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val);
+  } else if (SDValue TargetLowered = TLI.lowerVectorCopyReg(
+                 IsABIRegCopy, DAG, DL, Val, ValueVT, BuiltVectorTy)) {
+    // Give targets a chance to custom lower mismatched sizes
+    Val = TargetLowered;
   } else {
     if (BuiltVectorTy.getVectorElementType().bitsGT(
             ValueVT.getVectorElementType())) {
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 85c9ed489e926ce..a7116518aaadd0d 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10513,7 +10513,7 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
 // performed.
 static const std::optional<ByteProvider<SDValue>>
 calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
-                 unsigned Depth = 0) {
+                 bool IsSigned = 0, unsigned Depth = 0) {
   // We may need to recursively traverse a series of SRLs
   if (Depth >= 6)
     return std::nullopt;
@@ -10524,12 +10524,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
   switch (Op->getOpcode()) {
   case ISD::TRUNCATE: {
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND_INREG: {
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
     SDValue NarrowOp = Op->getOperand(0);
     auto NarrowVT = NarrowOp.getValueType();
     if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
@@ -10542,7 +10545,8 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
     if (SrcIndex >= NarrowByteWidth)
       return std::nullopt;
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SRA:
@@ -10558,11 +10562,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
     SrcIndex += BitShift / 8;
 
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   default: {
-    return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
+    if (auto L = dyn_cast<LoadSDNode>(Op))
+      IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
+
+    return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
   }
   }
   llvm_unreachable("fully handled switch");
@@ -10576,7 +10584,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 // performed. \p StartingIndex is the originally requested byte of the Or
 static const std::optional<ByteProvider<SDValue>>
 calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
-                      unsigned StartingIndex = 0) {
+                      unsigned StartingIndex = 0, bool IsSigned = 0) {
   // Finding Src tree of RHS of or typically requires at least 1 additional
   // depth
   if (Depth > 6)
@@ -10591,11 +10599,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   switch (Op.getOpcode()) {
   case ISD::OR: {
     auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!RHS)
       return std::nullopt;
     auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!LHS)
       return std::nullopt;
     // A well formed Or will have two ByteProviders for each byte, one of which
@@ -10626,7 +10634,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
       return ByteProvider<SDValue>::getConstantZero();
     }
 
-    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
+    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned);
   }
 
   case ISD::SRA:
@@ -10651,7 +10659,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     // the SRL is Index + ByteShift
     return BytesProvided - ByteShift > Index
                ? calculateSrcByte(Op->getOperand(0), StartingIndex,
-                                  Index + ByteShift)
+                                  Index + ByteShift, IsSigned)
                : ByteProvider<SDValue>::getConstantZero();
   }
 
@@ -10672,7 +10680,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? ByteProvider<SDValue>::getConstantZero()
                : calculateByteProvider(Op.getOperand(0), Index - ByteShift,
-                                       Depth + 1, StartingIndex);
+                                       Depth + 1, StartingIndex, IsSigned);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -10691,13 +10699,17 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
+                Op->getOpcode() == ISD::AssertSext;
 
     if (Index >= NarrowByteWidth)
       return Op.getOpcode() == ISD::ZERO_EXTEND
                 ? std::optional<ByteProvider<SDValue>>(
                       ByteProvider<SDValue>::getConstantZero())
                  : std::nullopt;
-    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
+    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
+                                 IsSigned);
   }
 
   case ISD::TRUNCATE: {
@@ -10705,7 +10717,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
     if (NarrowByteWidth >= Index) {
       return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                   StartingIndex);
+                                   StartingIndex, IsSigned);
     }
 
     return std::nullopt;
@@ -10713,13 +10725,14 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::CopyFromReg: {
     if (BitWidth / 8 > Index)
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
 
     return std::nullopt;
   }
 
   case ISD::LOAD: {
     auto L = cast<LoadSDNode>(Op.getNode());
+    IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
@@ -10736,7 +10749,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     if (NarrowByteWidth > Index) {
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
     }
 
     return std::nullopt;
@@ -10744,7 +10757,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
-                                 Depth + 1, StartingIndex);
+                                 Depth + 1, StartingIndex, IsSigned);
 
   case ISD::EXTRACT_VECTOR_ELT: {
     auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -10759,7 +10772,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
-                            StartingIndex, Index);
+                            StartingIndex, Index, IsSigned);
   }
 
   case AMDGPUISD::PERM: {
@@ -10775,9 +10788,10 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
     auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;
 
-    return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
-                           : ByteProvider<SDValue>(
-                                 ByteProvider<SDValue>::getConstantZero());
+    return IdxMask != 0x0c
+               ? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
+               : ByteProvider<SDValue>(
+                     ByteProvider<SDValue>::getConstantZero());
   }
 
   default: {
@@ -12587,11 +12601,7 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
     auto MulIdx = isMul(LHS) ? 0 : 1;
 
     auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
-    bool IsSigned =
-        MulOpcode == AMDGPUISD::MUL_I24 ||
-        (MulOpcode == ISD::MUL &&
-         TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
-         !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+    std::optional<bool> IsSigned;
     SmallVector<ByteProvider<SDValue>, 4> Src0s;
     SmallVector<ByteProvider<SDValue>, 4> Src1s;
     SmallVector<SDValue, 4> Src2s;
@@ -12607,15 +12617,17 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
           (MulOpcode == ISD::MUL &&
            TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
            !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
-      if (IterIsSigned != IsSigned) {
-        break;
-      }
       auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
       if (!Src0)
         break;
       auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
       if (!Src1)
         break;
+      IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+      if (!IsSigned)
+        IsSigned = IterIsSigned;
+      if (IterIsSigned != *IsSigned)
+        break;
       placeSources(*Src0, *Src1, Src0s, Src1s, I);
       auto AddIdx = 1 - MulIdx;
       // Allow the special case where add (add (mul24, 0), mul24) became ->
@@ -12630,6 +12642,15 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
             handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
         if (!Src1)
           break;
+        auto IterIsSigned =
+            MulOpcode == AMDGPUISD::MUL_I24 ||
+            (MulOpcode == ISD::MUL &&
+             TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
+             !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+        IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+        assert(IsSigned);
+        if (IterIsSigned != *IsSigned)
+          break;
         placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
         Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
         ChainLength = I + 2;
@@ -12695,18 +12716,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
       Src1 = resolveSources(DAG, SL, Src1s, false, true);
     }
 
+    assert(IsSigned);
     SDValue Src2 =
-        DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
+        DAG.getExtOrTrunc(*IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
 
-    SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
-                                                 : Intrinsic::amdgcn_udot4,
+    SDValue IID = DAG.getTargetConstant(*IsSigned ? Intrinsic::amdgcn_sdot4
+                                                  : Intrinsic::amdgcn_udot4,
                                         SL, MVT::i64);
 
     assert(!VT.isVector());
     auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
                            Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));
 
-    return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
+    return DAG.getExtOrTrunc(*IsSigned, Dot, SL, VT);
   }
 
   if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/AMDGPU/idot4s.ll b/llvm/test/CodeGen/AMDGPU/idot4s.ll
index 7edd24f12982ebd..e521039ce9ac838 100644
--- a/llvm/test/CodeGen/AMDGPU/idot4s.ll
+++ b/llvm/test/CodeGen/AMDGPU/idot4s.ll
@@ -143,7 +143,7 @@ define amdgpu_kernel void @idot4_acc32(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -352,7 +352,7 @@ define amdgpu_kernel void @idot4_acc16(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    global_load_i16 v3, v1, s[0:1]
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b16 v1, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -732,7 +732,7 @@ define amdgpu_kernel void @idot4_multiuse_mul1(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
 ; GFX11-DL-NEXT:    v_mad_i32_i24 v2, v2, v3, s2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v3, 0
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v3, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -922,7 +922,7 @@ define amdgpu_kernel void @idot4_acc32_vecMul(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1356,7 +1356,7 @@ define amdgpu_kernel void @idot4_acc32_2ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc0c0100
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1534,7 +1534,7 @@ define amdgpu_kernel void @idot4_acc32_3ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0,...

…yToReg / CopyFromReg

Change-Id: I7c888ad2c3cd7f1104aed47725852e2fe09b7665
@llvmbot
Member

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-backend-amdgpu


Contributor

@arsenm arsenm left a comment


We might be better off just writing an IR pass to pre-optimize these to legal integers

@jrbyrnes
Contributor Author

We might be better off just writing an IR pass to pre-optimize these to legal integers

I'll look into the details of this

@jrbyrnes
Contributor Author

For a demonstration of the issue --
https://godbolt.org/z/Y7MhcjGE8
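
The gist of the reproducer is a small byte vector that is live across a basic block boundary. A hedged approximation (not the exact source behind the link) is:

```c++
// Hypothetical reduction of the linked example. Compiled for amdgcn, the
// <4 x i8> value live across the branch is what produces the scalarized
// CopyToReg / CopyFromReg pairs described above.
typedef char char4 __attribute__((ext_vector_type(4)));

char4 cross_block(char4 a, char4 b, bool p) {
  char4 v = a + b; // v4i8 defined in the entry block
  if (p)           // basic block boundary
    v = v * b;     // v4i8 live across the edge
  return v;
}
```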

@arsenm
Contributor

arsenm commented Sep 14, 2023

We might be better off just writing an IR pass to pre-optimize these to legal integers

I'll look into the details of this

And by IR pass I mean AMDGPUCodeGenPrepare
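
A purely illustrative sketch of what such an IR-level coercion could look like (this is not existing AMDGPUCodeGenPrepare code; the function name and the phi-only scope are assumptions for the example):

```c++
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Rewrite `phi <4 x i8>` into `phi i32` plus bitcasts, so the value crossing
// the block boundary is already a legal integer before instruction selection.
static void coerceV4I8Phi(PHINode &Phi) {
  Type *I32Ty = Type::getInt32Ty(Phi.getContext());
  IRBuilder<> B(&Phi);
  PHINode *NewPhi = B.CreatePHI(I32Ty, Phi.getNumIncomingValues());

  // Bitcast each incoming <4 x i8> value to i32 in its predecessor block.
  for (unsigned I = 0, E = Phi.getNumIncomingValues(); I != E; ++I) {
    BasicBlock *Pred = Phi.getIncomingBlock(I);
    IRBuilder<> PredB(Pred->getTerminator());
    NewPhi->addIncoming(PredB.CreateBitCast(Phi.getIncomingValue(I), I32Ty),
                        Pred);
  }

  // Bitcast back to <4 x i8> after the phis for the existing users.
  B.SetInsertPoint(&*Phi.getParent()->getFirstInsertionPt());
  Value *AsVec = B.CreateBitCast(NewPhi, Phi.getType());
  Phi.replaceAllUsesWith(AsVec);
  Phi.eraseFromParent();
}
```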

@arsenm
Contributor

arsenm commented Jan 11, 2024

Is this still needed?

@jrbyrnes
Contributor Author

Abandoning in favor of #66838

@jrbyrnes jrbyrnes closed this Jan 12, 2024