[SelectionDAG] NFC: Add target hooks to enable vector coercion in CopyToReg / CopyFromReg #66134
Conversation
@llvm/pr-subscribers-llvm-selectiondag

Changes

In cases where we are copying vectors of non-legal types across basic block boundaries, we emit CopyToReg / CopyFromReg pairs. For such types, we typically scalarize, legalize each scalar, and then emit a CopyTo / CopyFrom for each scalar. However, in some cases we may be able to pack the vector into a single register, or into fewer registers. As an example, AMDGPU can pack a v4i8 into a single register by treating it as an i32 (rather than four registers, each of type i16). This NFC patch introduces the target hooks to implement such functionality.

Patch is 37.40 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/66134.diff

6 Files Affected:

```diff
diff --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h
index 3187b4e68c56f3a..7fc2d9a876e6710 100644
--- a/llvm/include/llvm/CodeGen/ByteProvider.h
+++ b/llvm/include/llvm/CodeGen/ByteProvider.h
@@ -32,6 +32,11 @@ template class ByteProvider {
   ByteProvider(std::optional Src, int64_t DestOffset, int64_t SrcOffset)
       : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
 
+  ByteProvider(std::optional Src, int64_t DestOffset, int64_t SrcOffset,
+               bool IsSigned)
+      : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset),
+        IsSigned(IsSigned) {}
+
   // TODO -- use constraint in c++20
   // Does this type correspond with an operation in selection DAG
   template class is_op {
@@ -61,13 +66,17 @@ template class ByteProvider {
   // DestOffset
   int64_t SrcOffset = 0;
 
+  // Tracks whether or not the byte is treated as a signed operand -- useful
+  // for arithmetic combines.
+  bool IsSigned = 0;
+
   ByteProvider() = default;
 
   static ByteProvider getSrc(std::optional Val, int64_t ByteOffset,
-                             int64_t VectorOffset) {
+                             int64_t VectorOffset, bool IsSigned = 0) {
     static_assert(is_op().value,
                   "ByteProviders must contain an operation in selection DAG.");
-    return ByteProvider(Val, ByteOffset, VectorOffset);
+    return ByteProvider(Val, ByteOffset, VectorOffset, IsSigned);
   }
 
   static ByteProvider getConstantZero() {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 12b280d5b1a0bcd..4066fccf2312abe 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1076,10 +1076,10 @@ class TargetLoweringBase {
   /// This method returns the number of registers needed, and the VT for each
   /// register. It also returns the VT and quantity of the intermediate values
   /// before they are promoted/expanded.
-  unsigned getVectorTypeBreakdown(LLVMContext &Context, EVT VT,
-                                  EVT &IntermediateVT,
-                                  unsigned &NumIntermediates,
-                                  MVT &RegisterVT) const;
+  virtual unsigned getVectorTypeBreakdown(LLVMContext &Context, EVT VT,
+                                          EVT &IntermediateVT,
+                                          unsigned &NumIntermediates,
+                                          MVT &RegisterVT) const;
 
   /// Certain targets such as MIPS require that some types such as vectors are
   /// always broken down into scalars in some contexts. This occurs even if the
@@ -1091,6 +1091,16 @@ class TargetLoweringBase {
                                     RegisterVT);
   }
 
+  /// Certain targets, such as AMDGPU, may coerce vectors of one type to another
+  /// to produce optimal code for CopyToReg / CopyFromReg pairs when dealing
+  /// with non-legal types -- e.g. v7i8 -> v2i32. This gives targets an
+  /// opportunity to do custom lowering in such cases.
+  virtual SDValue lowerVectorCopyReg(bool ISABIRegCopy, SelectionDAG &DAG,
+                                     const SDLoc &DL, SDValue &Val, EVT Source,
+                                     EVT Dest, bool IsCopyTo = true) const {
+    return SDValue();
+  };
+
   struct IntrinsicInfo {
     unsigned opc = 0;          // target opcode
     EVT memVT;                 // memory VT
@@ -1598,7 +1608,7 @@ class TargetLoweringBase {
   }
 
   /// Return the type of registers that this ValueType will eventually require.
-  MVT getRegisterType(LLVMContext &Context, EVT VT) const {
+  virtual MVT getRegisterType(LLVMContext &Context, EVT VT) const {
     if (VT.isSimple())
       return getRegisterType(VT.getSimpleVT());
     if (VT.isVector()) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5a227ba398e1c11..afa6caae889dabb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -400,6 +400,11 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
     if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
       return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
 
+    if (auto TargetLowered = TLI.lowerVectorCopyReg(IsABIRegCopy, DAG, DL, Val,
+                                                    PartEVT, ValueVT, false)) {
+      // Give targets a chance to custom lower mismatched sizes
+      return TargetLowered;
+    }
     // If the parts vector has more elements than the value vector, then we
     // have a vector widening case (e.g. <2 x float> -> <4 x float>).
     // Extract the elements we want.
@@ -765,6 +770,10 @@ static void getCopyToPartsVector(SelectionDAG &DAG, const SDLoc &DL,
   } else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) {
     // Bitconvert vector->vector case.
     Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val);
+  } else if (SDValue TargetLowered = TLI.lowerVectorCopyReg(
+                 IsABIRegCopy, DAG, DL, Val, ValueVT, BuiltVectorTy)) {
+    // Give targets a chance to custom lower mismatched sizes
+    Val = TargetLowered;
   } else {
     if (BuiltVectorTy.getVectorElementType().bitsGT(
             ValueVT.getVectorElementType())) {
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 85c9ed489e926ce..a7116518aaadd0d 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10513,7 +10513,7 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
 // performed.
 static const std::optional>
 calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
-                 unsigned Depth = 0) {
+                 bool IsSigned = 0, unsigned Depth = 0) {
   // We may need to recursively traverse a series of SRLs
   if (Depth >= 6)
     return std::nullopt;
@@ -10524,12 +10524,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
   switch (Op->getOpcode()) {
   case ISD::TRUNCATE: {
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND_INREG: {
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
     SDValue NarrowOp = Op->getOperand(0);
     auto NarrowVT = NarrowOp.getValueType();
     if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
@@ -10542,7 +10545,8 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
     if (SrcIndex >= NarrowByteWidth)
       return std::nullopt;
 
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   case ISD::SRA:
@@ -10558,11 +10562,15 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 
     SrcIndex += BitShift / 8;
 
-    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
+    return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
+                            Depth + 1);
   }
 
   default: {
-    return ByteProvider::getSrc(Op, DestByte, SrcIndex);
+    if (auto L = dyn_cast(Op))
+      IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
+
+    return ByteProvider::getSrc(Op, DestByte, SrcIndex, IsSigned);
   }
   }
   llvm_unreachable("fully handled switch");
@@ -10576,7 +10584,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
 // performed. \p StartingIndex is the originally requested byte of the Or
 static const std::optional>
 calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
-                      unsigned StartingIndex = 0) {
+                      unsigned StartingIndex = 0, bool IsSigned = 0) {
   // Finding Src tree of RHS of or typically requires at least 1 additional
   // depth
   if (Depth > 6)
@@ -10591,11 +10599,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   switch (Op.getOpcode()) {
   case ISD::OR: {
     auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!RHS)
       return std::nullopt;
     auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                     StartingIndex);
+                                     StartingIndex, IsSigned);
     if (!LHS)
       return std::nullopt;
     // A well formed Or will have two ByteProviders for each byte, one of which
@@ -10626,7 +10634,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
       return ByteProvider::getConstantZero();
     }
 
-    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
+    return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned);
   }
 
   case ISD::SRA:
@@ -10651,7 +10659,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     // the SRL is Index + ByteShift
     return BytesProvided - ByteShift > Index
                ? calculateSrcByte(Op->getOperand(0), StartingIndex,
-                                  Index + ByteShift)
+                                  Index + ByteShift, IsSigned)
                : ByteProvider::getConstantZero();
   }
 
@@ -10672,7 +10680,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     return Index < ByteShift
                ? ByteProvider::getConstantZero()
                : calculateByteProvider(Op.getOperand(0), Index - ByteShift,
-                                       Depth + 1, StartingIndex);
+                                       Depth + 1, StartingIndex, IsSigned);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -10691,13 +10699,17 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
     uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+    IsSigned |= Op->getOpcode() == ISD::SIGN_EXTEND ||
+                Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
+                Op->getOpcode() == ISD::AssertSext;
 
     if (Index >= NarrowByteWidth)
       return Op.getOpcode() == ISD::ZERO_EXTEND
                  ? std::optional>(
                        ByteProvider::getConstantZero())
                  : std::nullopt;
-    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
+    return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
+                                 IsSigned);
   }
 
   case ISD::TRUNCATE: {
@@ -10705,7 +10717,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
     if (NarrowByteWidth >= Index) {
       return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
-                                   StartingIndex);
+                                   StartingIndex, IsSigned);
     }
 
     return std::nullopt;
@@ -10713,13 +10725,14 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::CopyFromReg: {
     if (BitWidth / 8 > Index)
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
 
     return std::nullopt;
   }
 
   case ISD::LOAD: {
     auto L = cast(Op.getNode());
+    IsSigned |= L->getExtensionType() == ISD::SEXTLOAD;
 
     unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
     if (NarrowBitWidth % 8 != 0)
       return std::nullopt;
@@ -10736,7 +10749,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     if (NarrowByteWidth > Index) {
-      return calculateSrcByte(Op, StartingIndex, Index);
+      return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
     }
 
     return std::nullopt;
@@ -10744,7 +10757,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
-                                 Depth + 1, StartingIndex);
+                                 Depth + 1, StartingIndex, IsSigned);
 
   case ISD::EXTRACT_VECTOR_ELT: {
     auto IdxOp = dyn_cast(Op->getOperand(1));
@@ -10759,7 +10772,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     }
 
     return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
-                            StartingIndex, Index);
+                            StartingIndex, Index, IsSigned);
   }
 
   case AMDGPUISD::PERM: {
@@ -10775,9 +10788,10 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
     auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;
 
-    return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
-                           : ByteProvider(
-                                 ByteProvider::getConstantZero());
+    return IdxMask != 0x0c
+               ? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
+               : ByteProvider(
+                     ByteProvider::getConstantZero());
   }
 
   default: {
@@ -12587,11 +12601,7 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
     auto MulIdx = isMul(LHS) ? 0 : 1;
     auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
-    bool IsSigned =
-        MulOpcode == AMDGPUISD::MUL_I24 ||
-        (MulOpcode == ISD::MUL &&
-         TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
-         !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+    std::optional IsSigned;
     SmallVector, 4> Src0s;
     SmallVector, 4> Src1s;
     SmallVector Src2s;
@@ -12607,15 +12617,17 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
           (MulOpcode == ISD::MUL &&
            TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
            !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
-      if (IterIsSigned != IsSigned) {
-        break;
-      }
       auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
       if (!Src0)
         break;
       auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
       if (!Src1)
         break;
+      IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+      if (!IsSigned)
+        IsSigned = IterIsSigned;
+      if (IterIsSigned != *IsSigned)
+        break;
       placeSources(*Src0, *Src1, Src0s, Src1s, I);
       auto AddIdx = 1 - MulIdx;
       // Allow the special case where add (add (mul24, 0), mul24) became ->
@@ -12630,6 +12642,15 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
             handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
         if (!Src1)
           break;
+        auto IterIsSigned =
+            MulOpcode == AMDGPUISD::MUL_I24 ||
+            (MulOpcode == ISD::MUL &&
+             TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
+             !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
+        IterIsSigned |= Src0->IsSigned || Src1->IsSigned;
+        assert(IsSigned);
+        if (IterIsSigned != *IsSigned)
+          break;
         placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
         Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
         ChainLength = I + 2;
@@ -12695,18 +12716,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
       Src1 = resolveSources(DAG, SL, Src1s, false, true);
     }
 
+    assert(IsSigned);
     SDValue Src2 =
-        DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
+        DAG.getExtOrTrunc(*IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
 
-    SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
-                                                 : Intrinsic::amdgcn_udot4,
+    SDValue IID = DAG.getTargetConstant(*IsSigned ? Intrinsic::amdgcn_sdot4
+                                                  : Intrinsic::amdgcn_udot4,
                                         SL, MVT::i64);
 
     assert(!VT.isVector());
     auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
                            Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));
-    return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
+    return DAG.getExtOrTrunc(*IsSigned, Dot, SL, VT);
   }
 
   if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
diff --git a/llvm/test/CodeGen/AMDGPU/idot4s.ll b/llvm/test/CodeGen/AMDGPU/idot4s.ll
index 7edd24f12982ebd..e521039ce9ac838 100644
--- a/llvm/test/CodeGen/AMDGPU/idot4s.ll
+++ b/llvm/test/CodeGen/AMDGPU/idot4s.ll
@@ -143,7 +143,7 @@ define amdgpu_kernel void @idot4_acc32(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -352,7 +352,7 @@ define amdgpu_kernel void @idot4_acc16(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    global_load_i16 v3, v1, s[0:1]
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v2, v0, v3 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b16 v1, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -732,7 +732,7 @@ define amdgpu_kernel void @idot4_multiuse_mul1(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
 ; GFX11-DL-NEXT:    v_mad_i32_i24 v2, v2, v3, s2
 ; GFX11-DL-NEXT:    v_mov_b32_e32 v3, 0
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, v2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v3, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -922,7 +922,7 @@ define amdgpu_kernel void @idot4_acc32_vecMul(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    global_load_b32 v0, v0, s[6:7]
 ; GFX11-DL-NEXT:    s_load_b32 s2, s[0:1], 0x0
 ; GFX11-DL-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v1, v0, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1356,7 +1356,7 @@ define amdgpu_kernel void @idot4_acc32_2ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0, v0, 0xc0c0100
 ; GFX11-DL-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX11-DL-NEXT:    s_delay_alu instid0(VALU_DEP_1)
-; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2
+; GFX11-DL-NEXT:    v_dot4_i32_iu8 v0, v0, v1, s2 neg_lo:[1,1,0]
 ; GFX11-DL-NEXT:    global_store_b32 v2, v0, s[0:1]
 ; GFX11-DL-NEXT:    s_nop 0
 ; GFX11-DL-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
@@ -1534,7 +1534,7 @@ define amdgpu_kernel void @idot4_acc32_3ele(ptr addrspace(1) %src1,
 ; GFX11-DL-NEXT:    v_perm_b32 v0, v0,...
```
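The hook introduced above has a default body that simply returns an empty SDValue, so this patch by itself does not change generated code. Purely as an illustration of the kind of override the hook is meant to enable, a target could pack the v4i8 case from the description along the lines below; the helper name and the restriction to v4i8 -> i32 are assumptions for illustration, not code from this patch.

```cpp
// Sketch only: what a target-side handler for the new hook might do for the
// v4i8 case described in the summary. Assumed helper name; the patch itself
// only adds lowerVectorCopyReg with a default no-op body.
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/ValueTypes.h"

using namespace llvm;

static SDValue packV4I8ForCopy(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
                               EVT SourceVT, EVT DestVT) {
  // Handle only the exact case from the description: reinterpret the four i8
  // lanes as one i32 so the copy uses a single 32-bit register instead of
  // four promoted 16-bit registers.
  if (SourceVT != MVT::v4i8 || DestVT != MVT::i32)
    return SDValue();
  return DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
}
```

A real override would also have to handle the CopyFromReg direction and sizes that do not divide evenly, such as the v7i8 -> v2i32 case mentioned in the hook's documentation comment.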
We might be better off just writing an IR pass to pre-optimize these to legal integers
I'll look into the details of this.
For a demonstration of the issue --
And by IR pass I mean AMDGPUCodeGenPrepare.
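For reference, a minimal sketch of the IR-level idea being floated here (hypothetical helper, not the actual AMDGPUCodeGenPrepare change, and not the approach that ultimately landed): bitcast the illegal vector to a same-width integer in its defining block, and bitcast back just before the cross-block use, so the value that is live across the boundary already has a legal type.

```cpp
// Hypothetical sketch: coerce an illegal small vector (e.g. <4 x i8>) that is
// live across a basic block boundary into a same-width legal integer (i32).
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static void coerceCrossBlockVector(Instruction *Def, Instruction *UserInst) {
  Type *VecTy = Def->getType();                            // e.g. <4 x i8>
  unsigned Bits = VecTy->getPrimitiveSizeInBits().getFixedValue();
  Type *IntTy = Type::getIntNTy(Def->getContext(), Bits);  // e.g. i32

  // Pack the vector into an integer at the end of the defining block...
  IRBuilder<> DefBuilder(Def->getParent()->getTerminator());
  Value *Packed = DefBuilder.CreateBitCast(Def, IntTy, Def->getName() + ".pack");

  // ...and unpack it again immediately before the use in the other block.
  IRBuilder<> UseBuilder(UserInst);
  Value *Unpacked =
      UseBuilder.CreateBitCast(Packed, VecTy, Def->getName() + ".unpack");
  UserInst->replaceUsesOfWith(Def, Unpacked);
}
```

This only applies to fixed vectors whose bit width matches a legal integer type; a real pass would have to pick insertion points more carefully (PHI uses, multiple users) and decide which types are worth coercing.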
Is this still needed?
Abandoning in favor of #66838.
In cases where we are copying vectors across basic block boundaries, we emit CopyToReg / CopyFromReg pairs. For non-legal vector types, we typically scalarize, legalize each scalar, and then emit a CopyTo / CopyFrom for each scalar. However, in some cases we may be able to pack the vector into fewer registers. As an example, AMDGPU can pack a v4i8 into a single register by treating it as an i32 (rather than four registers, each of type i16).
This NFC patch introduces the target hooks to implement such functionality.
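As a rough illustration of the register-count claim above, the existing TargetLowering queries report the default breakdown; this sketch assumes a TargetLowering reference `TLI` and an LLVMContext `Ctx` are in scope.

```cpp
// Per the description: by default a v4i8 value crossing a block boundary is
// scalarized and promoted, ending up in four i16 registers on AMDGPU, where a
// single packed i32 register would suffice.
EVT VT = MVT::v4i8;
MVT RegVT = TLI.getRegisterType(Ctx, VT);          // i16, per the description
unsigned NumRegs = TLI.getNumRegisters(Ctx, VT);   // 4, per the description
```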