Commit daa4728

[AMDGPU] Add CodeGen support for GFX12 s_mul_u64 (#75825)
1 parent 23e03a8 commit daa4728

18 files changed: +2632 −628 lines

llvm/lib/Target/AMDGPU/AMDGPUCombine.td

Lines changed: 8 additions & 1 deletion
@@ -104,6 +104,13 @@ def foldable_fneg : GICombineRule<
          [{ return Helper.matchFoldableFneg(*${ffn}, ${matchinfo}); }]),
   (apply [{ Helper.applyFoldableFneg(*${ffn}, ${matchinfo}); }])>;
 
+// Detects s_mul_u64 instructions whose higher bits are zero/sign extended.
+def smulu64 : GICombineRule<
+  (defs root:$smul, unsigned_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_MUL):$smul,
+         [{ return matchCombine_s_mul_u64(*${smul}, ${matchinfo}); }]),
+  (apply [{ applyCombine_s_mul_u64(*${smul}, ${matchinfo}); }])>;
+
 def sign_exension_in_reg_matchdata : GIDefMatchData<"MachineInstr *">;
 
 def sign_extension_in_reg : GICombineRule<
@@ -149,7 +156,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
   "AMDGPUPostLegalizerCombinerImpl",
   [all_combines, gfx6gfx7_combines, gfx8_combines,
    uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
-   rcp_sqrt_to_rsq, sign_extension_in_reg]> {
+   rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
   let CombineAllMethodName = "tryCombineAllImpl";
 }
 
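Aside, not from the patch: the new smulu64 rule only fires when both 64-bit G_MUL operands are effectively 32-bit values (see the combiner changes further down). A minimal standalone C++ sketch of the unsigned equivalence it relies on, with illustrative values:

#include <cassert>
#include <cstdint>

int main() {
  // Both 64-bit operands have their upper 32 bits known to be zero, i.e. they
  // are zero-extended 32-bit values.
  uint64_t A = 0x00000000DEADBEEFULL;
  uint64_t B = 0x0000000012345678ULL;
  uint64_t Full = A * B;                                 // 64 x 64 -> 64 multiply
  uint64_t Narrow = (uint64_t)(uint32_t)A * (uint32_t)B; // 32 x 32 -> 64 multiply
  // Because the high halves are zero, the narrow product equals the full
  // product, which is why the combine may rewrite G_MUL to a 32-bit-source pseudo.
  assert(Full == Narrow);
  return 0;
}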

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 17 additions & 7 deletions
@@ -701,13 +701,23 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
         .maxScalar(0, S32);
   }
 
-  getActionDefinitionsBuilder(G_MUL)
-      .legalFor({S32, S16, V2S16})
-      .clampMaxNumElementsStrict(0, S16, 2)
-      .scalarize(0)
-      .minScalar(0, S16)
-      .widenScalarToNextMultipleOf(0, 32)
-      .custom();
+  if (ST.hasScalarSMulU64()) {
+    getActionDefinitionsBuilder(G_MUL)
+        .legalFor({S64, S32, S16, V2S16})
+        .clampMaxNumElementsStrict(0, S16, 2)
+        .scalarize(0)
+        .minScalar(0, S16)
+        .widenScalarToNextMultipleOf(0, 32)
+        .custom();
+  } else {
+    getActionDefinitionsBuilder(G_MUL)
+        .legalFor({S32, S16, V2S16})
+        .clampMaxNumElementsStrict(0, S16, 2)
+        .scalarize(0)
+        .minScalar(0, S16)
+        .widenScalarToNextMultipleOf(0, 32)
+        .custom();
+  }
   assert(ST.hasMad64_32());
 
   getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT, G_SADDSAT, G_SSUBSAT})
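Aside, not from the patch: the only difference between the two branches above is whether S64 is in the legalFor set; the rest of the chain is shared. For scalar widths that are not directly legal, minScalar(0, S16) and widenScalarToNextMultipleOf(0, 32) choose the width the later lowering sees. A rough standalone model of that width choice (my own simplification of the rule ordering, not code from the patch):

#include <cassert>

// Simplified model: widths at or below 16 end up as 16 (minScalar plus legal
// S16); anything wider is rounded up to the next multiple of 32.
static unsigned widenedMulWidth(unsigned Bits) {
  if (Bits <= 16)
    return 16;
  return (Bits + 31) / 32 * 32;
}

int main() {
  assert(widenedMulWidth(8) == 16);
  assert(widenedMulWidth(24) == 32);  // widened, then legal as S32
  assert(widenedMulWidth(48) == 64);  // 64 is legal only with ST.hasScalarSMulU64()
  assert(widenedMulWidth(64) == 64);
  return 0;
}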

llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp

Lines changed: 34 additions & 0 deletions
@@ -104,6 +104,14 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
   void applyCombineSignExtendInReg(MachineInstr &MI,
                                    MachineInstr *&MatchInfo) const;
 
+  // Find the s_mul_u64 instructions where the higher bits are either
+  // zero-extended or sign-extended.
+  bool matchCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;
+  // Replace the s_mul_u64 instructions with S_MUL_I64_I32_PSEUDO if the higher
+  // 33 bits are sign extended and with S_MUL_U64_U32_PSEUDO if the higher 32
+  // bits are zero extended.
+  void applyCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;
+
 private:
 #define GET_GICOMBINER_CLASS_MEMBERS
 #define AMDGPUSubtarget GCNSubtarget
@@ -419,6 +427,32 @@ void AMDGPUPostLegalizerCombinerImpl::applyCombineSignExtendInReg(
   MI.eraseFromParent();
 }
 
+bool AMDGPUPostLegalizerCombinerImpl::matchCombine_s_mul_u64(
+    MachineInstr &MI, unsigned &NewOpcode) const {
+  Register Src0 = MI.getOperand(1).getReg();
+  Register Src1 = MI.getOperand(2).getReg();
+  if (MRI.getType(Src0) != LLT::scalar(64))
+    return false;
+
+  if (KB->getKnownBits(Src1).countMinLeadingZeros() >= 32 &&
+      KB->getKnownBits(Src0).countMinLeadingZeros() >= 32) {
+    NewOpcode = AMDGPU::G_AMDGPU_S_MUL_U64_U32;
+    return true;
+  }
+
+  if (KB->computeNumSignBits(Src1) >= 33 &&
+      KB->computeNumSignBits(Src0) >= 33) {
+    NewOpcode = AMDGPU::G_AMDGPU_S_MUL_I64_I32;
+    return true;
+  }
+  return false;
+}
+
+void AMDGPUPostLegalizerCombinerImpl::applyCombine_s_mul_u64(
+    MachineInstr &MI, unsigned &NewOpcode) const {
+  Helper.replaceOpcodeWith(MI, NewOpcode);
+}
+
 // Pass boilerplate
 // ================
 
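Aside, not from the patch: matchCombine_s_mul_u64 asks for at least 33 sign bits (not 32) before choosing G_AMDGPU_S_MUL_I64_I32, because a 64-bit value with 33 or more sign bits is exactly a sign-extended 32-bit value. A minimal standalone C++ sketch of the signed equivalence, with illustrative values:

#include <cassert>
#include <cstdint>

int main() {
  // Both operands fit in 32 signed bits, i.e. they carry >= 33 sign bits when
  // held in a 64-bit register.
  int64_t A = -12345;
  int64_t B = 987654321;
  int64_t Full = A * B;                              // 64 x 64 -> 64 multiply
  int64_t Narrow = (int64_t)(int32_t)A * (int32_t)B; // 32 x 32 -> 64 multiply
  // Truncating and re-widening the operands preserves the product, so the
  // G_MUL can be represented by the signed 32-bit-source pseudo.
  assert(Full == Narrow);
  return 0;
}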

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 147 additions & 3 deletions
@@ -2094,6 +2094,74 @@ bool AMDGPURegisterBankInfo::foldInsertEltToCmpSelect(
   return true;
 }
 
+// Break s_mul_u64 into 32-bit vector operations.
+void AMDGPURegisterBankInfo::applyMappingSMULU64(
+    MachineIRBuilder &B, const OperandsMapper &OpdMapper) const {
+  SmallVector<Register, 2> DefRegs(OpdMapper.getVRegs(0));
+  SmallVector<Register, 2> Src0Regs(OpdMapper.getVRegs(1));
+  SmallVector<Register, 2> Src1Regs(OpdMapper.getVRegs(2));
+
+  // All inputs are SGPRs, nothing special to do.
+  if (DefRegs.empty()) {
+    assert(Src0Regs.empty() && Src1Regs.empty());
+    applyDefaultMapping(OpdMapper);
+    return;
+  }
+
+  assert(DefRegs.size() == 2);
+  assert(Src0Regs.size() == Src1Regs.size() &&
+         (Src0Regs.empty() || Src0Regs.size() == 2));
+
+  MachineRegisterInfo &MRI = OpdMapper.getMRI();
+  MachineInstr &MI = OpdMapper.getMI();
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT HalfTy = LLT::scalar(32);
+
+  // Depending on where the source registers came from, the generic code may
+  // have decided to split the inputs already or not. If not, we still need to
+  // extract the values.
+
+  if (Src0Regs.empty())
+    split64BitValueForMapping(B, Src0Regs, HalfTy, MI.getOperand(1).getReg());
+  else
+    setRegsToType(MRI, Src0Regs, HalfTy);
+
+  if (Src1Regs.empty())
+    split64BitValueForMapping(B, Src1Regs, HalfTy, MI.getOperand(2).getReg());
+  else
+    setRegsToType(MRI, Src1Regs, HalfTy);
+
+  setRegsToType(MRI, DefRegs, HalfTy);
+
+  // The multiplication is done as follows:
+  //
+  //                            Op1H  Op1L
+  //                          * Op0H  Op0L
+  //                   --------------------
+  //                   Op1H*Op0L  Op1L*Op0L
+  // +                 Op1H*Op0H  Op1L*Op0H
+  // -----------------------------------------
+  //  (Op1H*Op0L + Op1L*Op0H + carry)  Op1L*Op0L
+  //
+  // We drop Op1H*Op0H because the result of the multiplication is a 64-bit
+  // value and that would overflow.
+  // The low 32-bit value is Op1L*Op0L.
+  // The high 32-bit value is Op1H*Op0L + Op1L*Op0H + carry (from
+  // Op1L*Op0L).
+
+  ApplyRegBankMapping ApplyBank(B, *this, MRI, &AMDGPU::VGPRRegBank);
+
+  Register Hi = B.buildUMulH(HalfTy, Src0Regs[0], Src1Regs[0]).getReg(0);
+  Register MulLoHi = B.buildMul(HalfTy, Src0Regs[0], Src1Regs[1]).getReg(0);
+  Register Add = B.buildAdd(HalfTy, Hi, MulLoHi).getReg(0);
+  Register MulHiLo = B.buildMul(HalfTy, Src0Regs[1], Src1Regs[0]).getReg(0);
+  B.buildAdd(DefRegs[1], Add, MulHiLo);
+  B.buildMul(DefRegs[0], Src0Regs[0], Src1Regs[0]);
+
+  MRI.setRegBank(DstReg, AMDGPU::VGPRRegBank);
+  MI.eraseFromParent();
+}
+
 void AMDGPURegisterBankInfo::applyMappingImpl(
     MachineIRBuilder &B, const OperandsMapper &OpdMapper) const {
   MachineInstr &MI = OpdMapper.getMI();
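
Aside, not from the patch: applyMappingSMULU64 above implements the schoolbook decomposition described in its comment with buildMul/buildUMulH/buildAdd on 32-bit halves. A standalone C++ sketch of the same arithmetic, with illustrative names:

#include <cassert>
#include <cstdint>

// Mirror of the lowering: low = lo32(Op0L*Op1L); high = hi32(Op0L*Op1L)
// + lo32(Op0L*Op1H) + lo32(Op0H*Op1L). Op0H*Op1H is dropped because it only
// affects bits above 63.
static uint64_t mul64ViaHalves(uint64_t Op0, uint64_t Op1) {
  uint32_t Op0L = (uint32_t)Op0, Op0H = (uint32_t)(Op0 >> 32);
  uint32_t Op1L = (uint32_t)Op1, Op1H = (uint32_t)(Op1 >> 32);
  uint64_t LoFull = (uint64_t)Op0L * Op1L;        // G_MUL + G_UMULH pair
  uint32_t Hi = (uint32_t)(LoFull >> 32);         // carry/high part of Op0L*Op1L
  Hi += (uint32_t)((uint64_t)Op0L * Op1H);        // cross product, low 32 bits
  Hi += (uint32_t)((uint64_t)Op0H * Op1L);        // cross product, low 32 bits
  return ((uint64_t)Hi << 32) | (uint32_t)LoFull;
}

int main() {
  uint64_t A = 0x123456789ABCDEF0ULL, B = 0x0FEDCBA987654321ULL;
  assert(mul64ViaHalves(A, B) == A * B); // wrap-around 64-bit product matches
  return 0;
}
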
@@ -2394,13 +2462,21 @@ void AMDGPURegisterBankInfo::applyMappingImpl(
     Register DstReg = MI.getOperand(0).getReg();
     LLT DstTy = MRI.getType(DstReg);
 
+    // Special case for s_mul_u64. There is not a vector equivalent of
+    // s_mul_u64. Hence, we have to break down s_mul_u64 into 32-bit vector
+    // multiplications.
+    if (Opc == AMDGPU::G_MUL && DstTy.getSizeInBits() == 64) {
+      applyMappingSMULU64(B, OpdMapper);
+      return;
+    }
+
     // 16-bit operations are VALU only, but can be promoted to 32-bit SALU.
     // Packed 16-bit operations need to be scalarized and promoted.
     if (DstTy != LLT::scalar(16) && DstTy != LLT::fixed_vector(2, 16))
       break;
 
     const RegisterBank *DstBank =
-      OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
+        OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
     if (DstBank == &AMDGPU::VGPRRegBank)
       break;
 
@@ -2451,6 +2527,72 @@ void AMDGPURegisterBankInfo::applyMappingImpl(
 
     return;
   }
+  case AMDGPU::G_AMDGPU_S_MUL_I64_I32:
+  case AMDGPU::G_AMDGPU_S_MUL_U64_U32: {
+    // This is a special case for s_mul_u64. We use
+    // G_AMDGPU_S_MUL_I64_I32 opcode to represent an s_mul_u64 operation
+    // where the 33 higher bits are sign-extended and
+    // G_AMDGPU_S_MUL_U64_U32 opcode to represent an s_mul_u64 operation
+    // where the 32 higher bits are zero-extended. In case scalar registers are
+    // selected, both opcodes are lowered as s_mul_u64. If the vector registers
+    // are selected, then G_AMDGPU_S_MUL_I64_I32 and
+    // G_AMDGPU_S_MUL_U64_U32 are lowered with a vector mad instruction.
+
+    // Insert basic copies.
+    applyDefaultMapping(OpdMapper);
+
+    Register DstReg = MI.getOperand(0).getReg();
+    Register SrcReg0 = MI.getOperand(1).getReg();
+    Register SrcReg1 = MI.getOperand(2).getReg();
+    const LLT S32 = LLT::scalar(32);
+    const LLT S64 = LLT::scalar(64);
+    assert(MRI.getType(DstReg) == S64 && "This is a special case for s_mul_u64 "
+                                         "that handles only 64-bit operands.");
+    const RegisterBank *DstBank =
+        OpdMapper.getInstrMapping().getOperandMapping(0).BreakDown[0].RegBank;
+
+    // Replace G_AMDGPU_S_MUL_I64_I32 and G_AMDGPU_S_MUL_U64_U32
+    // with s_mul_u64 operation.
+    if (DstBank == &AMDGPU::SGPRRegBank) {
+      MI.setDesc(TII->get(AMDGPU::S_MUL_U64));
+      MRI.setRegClass(DstReg, &AMDGPU::SGPR_64RegClass);
+      MRI.setRegClass(SrcReg0, &AMDGPU::SGPR_64RegClass);
+      MRI.setRegClass(SrcReg1, &AMDGPU::SGPR_64RegClass);
+      return;
+    }
+
+    // Replace G_AMDGPU_S_MUL_I64_I32 and G_AMDGPU_S_MUL_U64_U32
+    // with a vector mad.
+    assert(MRI.getRegBankOrNull(DstReg) == &AMDGPU::VGPRRegBank &&
+           "The destination operand should be in vector registers.");
+
+    DebugLoc DL = MI.getDebugLoc();
+
+    // Extract the lower subregister from the first operand.
+    Register Op0L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+    MRI.setRegClass(Op0L, &AMDGPU::VGPR_32RegClass);
+    MRI.setType(Op0L, S32);
+    B.buildTrunc(Op0L, SrcReg0);
+
+    // Extract the lower subregister from the second operand.
+    Register Op1L = MRI.createVirtualRegister(&AMDGPU::VGPR_32RegClass);
+    MRI.setRegClass(Op1L, &AMDGPU::VGPR_32RegClass);
+    MRI.setType(Op1L, S32);
+    B.buildTrunc(Op1L, SrcReg1);
+
+    unsigned NewOpc = Opc == AMDGPU::G_AMDGPU_S_MUL_U64_U32
+                          ? AMDGPU::G_AMDGPU_MAD_U64_U32
+                          : AMDGPU::G_AMDGPU_MAD_I64_I32;
+
+    MachineIRBuilder B(MI);
+    Register Zero64 = B.buildConstant(S64, 0).getReg(0);
+    MRI.setRegClass(Zero64, &AMDGPU::VReg_64RegClass);
+    Register CarryOut = MRI.createVirtualRegister(&AMDGPU::VReg_64RegClass);
+    MRI.setRegClass(CarryOut, &AMDGPU::VReg_64RegClass);
+    B.buildInstr(NewOpc, {DstReg, CarryOut}, {Op0L, Op1L, Zero64});
+    MI.eraseFromParent();
+    return;
+  }
   case AMDGPU::G_SEXT_INREG: {
     SmallVector<Register, 2> SrcRegs(OpdMapper.getVRegs(1));
     if (SrcRegs.empty())
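
Aside, not from the patch: on the VGPR path above, the pseudo is rewritten into G_AMDGPU_MAD_U64_U32 / G_AMDGPU_MAD_I64_I32 with the truncated 32-bit sources and a 64-bit zero addend. A minimal standalone C++ model of why a mad with a zero addend yields the widened product (function name and values are illustrative; only the unsigned flavor is modeled):

#include <cassert>
#include <cstdint>

// Models D = A * B + C with 32-bit A, B and 64-bit C/D, the shape of the
// unsigned 64-bit mad used by this lowering.
static uint64_t madU64U32(uint32_t A, uint32_t B, uint64_t C) {
  return (uint64_t)A * B + C;
}

int main() {
  // The pseudo is only formed when the upper halves are known zero, so
  // truncating the operands loses nothing.
  uint64_t X = 0x00000000CAFEBABEULL;
  uint64_t Y = 0x0000000000031337ULL;
  assert(madU64U32((uint32_t)X, (uint32_t)Y, /*C=*/0) == X * Y);
  return 0;
}
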
@@ -3669,7 +3811,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
 
   case AMDGPU::G_AND:
   case AMDGPU::G_OR:
-  case AMDGPU::G_XOR: {
+  case AMDGPU::G_XOR:
+  case AMDGPU::G_MUL: {
     unsigned Size = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
     if (Size == 1) {
       const RegisterBank *DstBank
@@ -3737,7 +3880,6 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case AMDGPU::G_PTRMASK:
   case AMDGPU::G_ADD:
   case AMDGPU::G_SUB:
-  case AMDGPU::G_MUL:
   case AMDGPU::G_SHL:
   case AMDGPU::G_LSHR:
   case AMDGPU::G_ASHR:
@@ -3755,6 +3897,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   case AMDGPU::G_SHUFFLE_VECTOR:
   case AMDGPU::G_SBFX:
   case AMDGPU::G_UBFX:
+  case AMDGPU::G_AMDGPU_S_MUL_I64_I32:
+  case AMDGPU::G_AMDGPU_S_MUL_U64_U32:
     if (isSALUMapping(MI))
       return getDefaultMappingSOP(MI);
     return getDefaultMappingVOP(MI);

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,9 @@ class AMDGPURegisterBankInfo final : public AMDGPUGenRegisterBankInfo {
   bool applyMappingMAD_64_32(MachineIRBuilder &B,
                              const OperandsMapper &OpdMapper) const;
 
+  void applyMappingSMULU64(MachineIRBuilder &B,
+                           const OperandsMapper &OpdMapper) const;
+
   Register handleD16VData(MachineIRBuilder &B, MachineRegisterInfo &MRI,
                           Register Reg) const;
 
llvm/lib/Target/AMDGPU/GCNSubtarget.h

Lines changed: 2 additions & 0 deletions
@@ -683,6 +683,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
 
   bool hasScalarAddSub64() const { return getGeneration() >= GFX12; }
 
+  bool hasScalarSMulU64() const { return getGeneration() >= GFX12; }
+
   bool hasUnpackedD16VMem() const {
     return HasUnpackedD16VMem;
   }
