Skip to content

Commit 7fda1b7

Browse files
committed
[AMDGPU]: Allow combining into v_dot4
Differential Revision: https://reviews.llvm.org/D155995 Change-Id: I794f540217f0f84141338757b41b1be0493c7207
1 parent a09f09c commit 7fda1b7

File tree

4 files changed

+5077
-392
lines changed

4 files changed

+5077
-392
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 315 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12357,6 +12357,193 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
1235712357
return Accum;
1235812358
}
1235912359

12360+
// Collect the ultimate src of each of the mul24 node's operands, and confirm
12361+
// each operand is 8 bytes.
12362+
static std::optional<ByteProvider<SDValue>>
12363+
handleMulOperand(const SDValue &MulOperand) {
12364+
auto Byte0 = calculateByteProvider(MulOperand, 0, 0);
12365+
if (!Byte0 || Byte0->isConstantZero()) {
12366+
return std::nullopt;
12367+
}
12368+
auto Byte1 = calculateByteProvider(MulOperand, 1, 0);
12369+
if (Byte1 && !Byte1->isConstantZero()) {
12370+
return std::nullopt;
12371+
}
12372+
return Byte0;
12373+
}
12374+
12375+
static unsigned addPermMasks(unsigned First, unsigned Second) {
12376+
unsigned FirstCs = First & 0x0c0c0c0c;
12377+
unsigned SecondCs = Second & 0x0c0c0c0c;
12378+
unsigned FirstNoCs = First & ~0x0c0c0c0c;
12379+
unsigned SecondNoCs = Second & ~0x0c0c0c0c;
12380+
12381+
assert(FirstCs & 0xFF | SecondCs & 0xFF);
12382+
assert(FirstCs & 0xFF00 | SecondCs & 0xFF00);
12383+
assert(FirstCs & 0xFF0000 | SecondCs & 0xFF0000);
12384+
assert(FirstCs & 0xFF000000 | SecondCs & 0xFF000000);
12385+
12386+
return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
12387+
}
12388+
12389+
static void placeSources(ByteProvider<SDValue> &Src0,
12390+
ByteProvider<SDValue> &Src1,
12391+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12392+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12393+
int Step) {
12394+
12395+
assert(Src0.Src.has_value() && Src1.Src.has_value());
12396+
// Src0s and Src1s are empty, just place arbitrarily
12397+
if (Step == 0) {
12398+
Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12399+
Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12400+
return;
12401+
}
12402+
12403+
for (int BPI = 0; BPI < 2; BPI++) {
12404+
std::pair<ByteProvider<SDValue>, ByteProvider<SDValue>> BPP = {Src0, Src1};
12405+
if (BPI == 1) {
12406+
BPP = {Src1, Src0};
12407+
}
12408+
unsigned ZeroMask = 0x0c0c0c0c;
12409+
unsigned FMask = 0xFF << (8 * (3 - Step));
12410+
12411+
unsigned FirstMask =
12412+
BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12413+
unsigned SecondMask =
12414+
BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12415+
// Attempt to find Src vector which contains our SDValue, if so, add our
12416+
// perm mask to the existing one. If we are unable to find a match for the
12417+
// first SDValue, attempt to find match for the second.
12418+
int FirstGroup = -1;
12419+
for (int I = 0; I < 2; I++) {
12420+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12421+
I == 0 ? Src0s : Src1s;
12422+
auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12423+
return IterElt.first == *BPP.first.Src;
12424+
};
12425+
12426+
auto Match = std::find_if(Srcs.begin(), Srcs.end(), MatchesFirst);
12427+
if (Match != Srcs.end()) {
12428+
Match->second = addPermMasks(FirstMask, Match->second);
12429+
FirstGroup = I;
12430+
break;
12431+
}
12432+
}
12433+
if (FirstGroup != -1) {
12434+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12435+
FirstGroup == 1 ? Src0s : Src1s;
12436+
auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12437+
return IterElt.first == *BPP.second.Src;
12438+
};
12439+
auto Match = std::find_if(Srcs.begin(), Srcs.end(), MatchesSecond);
12440+
if (Match != Srcs.end()) {
12441+
Match->second = addPermMasks(SecondMask, Match->second);
12442+
} else
12443+
Srcs.push_back({*BPP.second.Src, SecondMask});
12444+
return;
12445+
}
12446+
}
12447+
12448+
// If we have made it here, then we could not find a match in Src0s or Src1s
12449+
// for either Src0 or Src1, so just place them arbitrarily.
12450+
12451+
unsigned ZeroMask = 0x0c0c0c0c;
12452+
unsigned FMask = 0xFF << (8 * (3 - Step));
12453+
12454+
Src0s.push_back(
12455+
{*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12456+
Src1s.push_back(
12457+
{*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12458+
12459+
return;
12460+
}
12461+
12462+
static SDValue
12463+
resolveSources(SelectionDAG &DAG, SDLoc SL,
12464+
SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12465+
bool IsSigned, bool IsAny) {
12466+
12467+
// If we just have one source, just permute it accordingly.
12468+
if (Srcs.size() == 1) {
12469+
auto Elt = Srcs.begin();
12470+
auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
12471+
12472+
// v_perm will produce the original value
12473+
if (Elt->second == 0x3020100)
12474+
return EltVal;
12475+
12476+
return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12477+
DAG.getConstant(Elt->second, SL, MVT::i32));
12478+
}
12479+
12480+
auto FirstElt = Srcs.begin();
12481+
auto SecondElt = std::next(FirstElt);
12482+
12483+
SmallVector<SDValue, 2> Perms;
12484+
12485+
// If we have multiple sources in the chain, combine them via perms (using
12486+
// calculated perm mask) and Ors.
12487+
while (true) {
12488+
auto FirstMask = FirstElt->second;
12489+
auto SecondMask = SecondElt->second;
12490+
12491+
unsigned FirstCs = FirstMask & 0x0c0c0c0c;
12492+
unsigned FirstPlusFour = FirstMask | 0x04040404;
12493+
// 0x0c + 0x04 = 0x10, so anding with 0x0F will produced 0x00 for any
12494+
// original 0x0C
12495+
FirstMask = (FirstPlusFour & 0x0F0F0F0F) | FirstCs;
12496+
12497+
auto PermMask = addPermMasks(FirstMask, SecondMask);
12498+
auto FirstVal =
12499+
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12500+
auto SecondVal =
12501+
DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
12502+
12503+
Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
12504+
SecondVal,
12505+
DAG.getConstant(PermMask, SL, MVT::i32)));
12506+
12507+
FirstElt = std::next(SecondElt);
12508+
if (FirstElt == Srcs.end())
12509+
break;
12510+
12511+
SecondElt = std::next(FirstElt);
12512+
// If we only have a FirstElt, then just combine that into the cumulative
12513+
// source node
12514+
if (SecondElt == Srcs.end()) {
12515+
auto EltVal =
12516+
DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12517+
12518+
Perms.push_back(
12519+
DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12520+
DAG.getConstant(FirstElt->second, SL, MVT::i32)));
12521+
break;
12522+
}
12523+
}
12524+
12525+
assert(Perms.size() == 1 || Perms.size() == 2);
12526+
return Perms.size() == 2
12527+
? DAG.getNode(ISD::OR, SL, MVT::i32, Perms[0], Perms[1])
12528+
: Perms[0];
12529+
}
12530+
12531+
static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12532+
unsigned ChainLength) {
12533+
for (auto &[EntryVal, EntryMask] : Srcs) {
12534+
EntryMask = EntryMask >> ((4 - ChainLength) * 8);
12535+
auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
12536+
EntryMask += ZeroMask;
12537+
}
12538+
}
12539+
12540+
static bool isMul(const SDValue Op) {
12541+
auto Opcode = Op.getOpcode();
12542+
12543+
return (Opcode == ISD::MUL || Opcode == AMDGPUISD::MUL_U24 ||
12544+
Opcode == AMDGPUISD::MUL_I24);
12545+
}
12546+
1236012547
SDValue SITargetLowering::performAddCombine(SDNode *N,
1236112548
DAGCombinerInfo &DCI) const {
1236212549
SelectionDAG &DAG = DCI.DAG;
@@ -12370,14 +12557,140 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
1237012557
if (SDValue Folded = tryFoldToMad64_32(N, DCI))
1237112558
return Folded;
1237212559
}
12373-
12374-
return SDValue();
1237512560
}
1237612561

1237712562
if (SDValue V = reassociateScalarOps(N, DAG)) {
1237812563
return V;
1237912564
}
1238012565

12566+
if ((isMul(LHS) || isMul(RHS)) && Subtarget->hasDot7Insts() &&
12567+
(Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
12568+
SDValue TempNode(N, 0);
12569+
auto MulIdx = isMul(LHS) ? 0 : 1;
12570+
12571+
auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
12572+
bool IsSigned =
12573+
MulOpcode == AMDGPUISD::MUL_I24 ||
12574+
(MulOpcode == ISD::MUL &&
12575+
TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
12576+
!TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
12577+
SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
12578+
SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
12579+
SmallVector<SDValue, 4> Src2s;
12580+
12581+
// Match the v_dot4 tree, while collecting src nodes.
12582+
int ChainLength = 0;
12583+
for (int I = 0; I < 4; I++) {
12584+
auto MulIdx = isMul(LHS) ? 0 : isMul(RHS) ? 1 : -1;
12585+
if (MulIdx == -1)
12586+
break;
12587+
auto IterIsSigned =
12588+
MulOpcode == AMDGPUISD::MUL_I24 ||
12589+
(MulOpcode == ISD::MUL &&
12590+
TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
12591+
!TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
12592+
if (IterIsSigned != IsSigned) {
12593+
break;
12594+
}
12595+
auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
12596+
if (!Src0)
12597+
break;
12598+
auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
12599+
if (!Src1)
12600+
break;
12601+
placeSources(*Src0, *Src1, Src0s, Src1s, I);
12602+
auto AddIdx = 1 - MulIdx;
12603+
// Allow the special case where add (add (mul24, 0), mul24) became ->
12604+
// add (mul24, mul24)
12605+
if (I == 2 && isMul(TempNode->getOperand(AddIdx))) {
12606+
Src2s.push_back(TempNode->getOperand(AddIdx));
12607+
auto Src0 =
12608+
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(0));
12609+
if (!Src0)
12610+
break;
12611+
auto Src1 =
12612+
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
12613+
if (!Src1)
12614+
break;
12615+
placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
12616+
Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
12617+
ChainLength = I + 2;
12618+
break;
12619+
}
12620+
12621+
TempNode = TempNode->getOperand(AddIdx);
12622+
Src2s.push_back(TempNode);
12623+
ChainLength = I + 1;
12624+
if (TempNode->getNumOperands() < 2)
12625+
break;
12626+
LHS = TempNode->getOperand(0);
12627+
RHS = TempNode->getOperand(1);
12628+
}
12629+
12630+
if (ChainLength < 2)
12631+
return SDValue();
12632+
12633+
// Masks were constructed with assumption that we would find a chain of
12634+
// length 4. If not, then we need to 0 out the MSB bits (via perm mask of
12635+
// 0x0c) so they do not affect dot calculation.
12636+
if (ChainLength < 4) {
12637+
fixMasks(Src0s, ChainLength);
12638+
fixMasks(Src1s, ChainLength);
12639+
}
12640+
12641+
SDValue Src0, Src1;
12642+
12643+
// If we are just using a single source for both, and have permuted the
12644+
// bytes consistently, we can just use the sources without permuting
12645+
// (commutation)
12646+
bool UseOriginalSrc = false;
12647+
if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
12648+
Src0s.begin()->second == Src1s.begin()->second &&
12649+
Src0s.begin()->first.getValueSizeInBits() == 32 &&
12650+
Src1s.begin()->first.getValueSizeInBits() == 32) {
12651+
SmallVector<unsigned, 4> SrcBytes;
12652+
auto Src0Mask = Src0s.begin()->second;
12653+
SrcBytes.push_back(Src0Mask & 0xFF000000);
12654+
bool UniqueEntries = true;
12655+
for (auto I = 1; I < 4; I++) {
12656+
auto NextByte = Src0Mask & (0xFF << ((3 - I) * 8));
12657+
12658+
if (is_contained(SrcBytes, NextByte)) {
12659+
UniqueEntries = false;
12660+
break;
12661+
}
12662+
SrcBytes.push_back(NextByte);
12663+
}
12664+
12665+
if (UniqueEntries) {
12666+
UseOriginalSrc = true;
12667+
// Must be 32 bits to enter above conditional
12668+
assert(Src0s.begin()->first.getValueSizeInBits() == 32);
12669+
assert(Src1s.begin()->first.getValueSizeInBits() == 32);
12670+
Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
12671+
Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
12672+
}
12673+
}
12674+
12675+
if (!UseOriginalSrc) {
12676+
Src0 = resolveSources(DAG, SL, Src0s, false, true);
12677+
Src1 = resolveSources(DAG, SL, Src1s, false, true);
12678+
}
12679+
12680+
SDValue Src2 =
12681+
DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
12682+
12683+
SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
12684+
: Intrinsic::amdgcn_udot4,
12685+
SL, MVT::i64);
12686+
12687+
assert(!VT.isVector());
12688+
auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
12689+
Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));
12690+
12691+
return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
12692+
}
12693+
1238112694
if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
1238212695
return SDValue();
1238312696

llvm/test/CodeGen/AMDGPU/idot2.ll

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,18 +2823,18 @@ define amdgpu_kernel void @notsdot2_sext8(ptr addrspace(1) %src1,
28232823
; GFX9-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
28242824
; GFX9-DL-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x34
28252825
; GFX9-DL-NEXT: v_lshlrev_b32_e32 v0, 1, v0
2826+
; GFX9-DL-NEXT: s_mov_b32 s1, 0xc0c0001
28262827
; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
28272828
; GFX9-DL-NEXT: global_load_ushort v1, v0, s[4:5]
28282829
; GFX9-DL-NEXT: global_load_ushort v2, v0, s[6:7]
28292830
; GFX9-DL-NEXT: s_load_dword s0, s[2:3], 0x0
28302831
; GFX9-DL-NEXT: v_mov_b32_e32 v0, 0
2832+
; GFX9-DL-NEXT: s_waitcnt vmcnt(1)
2833+
; GFX9-DL-NEXT: v_perm_b32 v1, v1, v1, s1
28312834
; GFX9-DL-NEXT: s_waitcnt vmcnt(0)
2832-
; GFX9-DL-NEXT: v_mul_i32_i24_sdwa v3, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
2833-
; GFX9-DL-NEXT: v_lshrrev_b16_e32 v1, 8, v1
2834-
; GFX9-DL-NEXT: v_lshrrev_b16_e32 v2, 8, v2
2835-
; GFX9-DL-NEXT: v_mul_i32_i24_sdwa v1, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
2835+
; GFX9-DL-NEXT: v_perm_b32 v2, v2, v2, s1
28362836
; GFX9-DL-NEXT: s_waitcnt lgkmcnt(0)
2837-
; GFX9-DL-NEXT: v_add3_u32 v1, v1, s0, v3
2837+
; GFX9-DL-NEXT: v_dot4_i32_i8 v1, v2, v1, s0
28382838
; GFX9-DL-NEXT: global_store_dword v0, v1, s[2:3]
28392839
; GFX9-DL-NEXT: s_endpgm
28402840
;
@@ -2843,21 +2843,20 @@ define amdgpu_kernel void @notsdot2_sext8(ptr addrspace(1) %src1,
28432843
; GFX10-DL-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x24
28442844
; GFX10-DL-NEXT: v_lshlrev_b32_e32 v0, 1, v0
28452845
; GFX10-DL-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x34
2846+
; GFX10-DL-NEXT: v_mov_b32_e32 v3, 0
28462847
; GFX10-DL-NEXT: s_waitcnt lgkmcnt(0)
28472848
; GFX10-DL-NEXT: s_clause 0x1
28482849
; GFX10-DL-NEXT: global_load_ushort v1, v0, s[4:5]
28492850
; GFX10-DL-NEXT: global_load_ushort v2, v0, s[6:7]
28502851
; GFX10-DL-NEXT: s_load_dword s2, s[0:1], 0x0
28512852
; GFX10-DL-NEXT: s_waitcnt vmcnt(1)
2852-
; GFX10-DL-NEXT: v_lshrrev_b16 v0, 8, v1
2853+
; GFX10-DL-NEXT: v_perm_b32 v0, v1, v1, 0xc0c0001
28532854
; GFX10-DL-NEXT: s_waitcnt vmcnt(0)
2854-
; GFX10-DL-NEXT: v_lshrrev_b16 v3, 8, v2
2855-
; GFX10-DL-NEXT: v_mul_i32_i24_sdwa v1, sext(v2), sext(v1) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
2856-
; GFX10-DL-NEXT: v_mov_b32_e32 v2, 0
2857-
; GFX10-DL-NEXT: v_mul_i32_i24_sdwa v0, sext(v3), sext(v0) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0
2855+
; GFX10-DL-NEXT: v_perm_b32 v1, v2, v2, 0xc0c0001
28582856
; GFX10-DL-NEXT: s_waitcnt lgkmcnt(0)
2859-
; GFX10-DL-NEXT: v_add3_u32 v0, v0, s2, v1
2860-
; GFX10-DL-NEXT: global_store_dword v2, v0, s[0:1]
2857+
; GFX10-DL-NEXT: v_mov_b32_e32 v2, s2
2858+
; GFX10-DL-NEXT: v_dot4c_i32_i8_e32 v2, v1, v0
2859+
; GFX10-DL-NEXT: global_store_dword v3, v2, s[0:1]
28612860
; GFX10-DL-NEXT: s_endpgm
28622861
ptr addrspace(1) %src2,
28632862
ptr addrspace(1) nocapture %dst) {

0 commit comments

Comments
 (0)