Skip to content

[AMDGPU][True16] Support source DPP operands. #79025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
}

bool isVRegWithInputMods() const;
bool isT16VRegWithInputMods() const;
template <bool IsFake16> bool isT16VRegWithInputMods() const;

bool isSDWAOperand(MVT type) const;
bool isSDWAFP16Operand() const;
Expand Down Expand Up @@ -2054,8 +2054,9 @@ bool AMDGPUOperand::isVRegWithInputMods() const {
AsmParser->getFeatureBits()[AMDGPU::FeatureDPALU_DPP]);
}

bool AMDGPUOperand::isT16VRegWithInputMods() const {
return isRegClass(AMDGPU::VGPR_32_Lo128RegClassID);
template <bool IsFake16> bool AMDGPUOperand::isT16VRegWithInputMods() const {
return isRegClass(IsFake16 ? AMDGPU::VGPR_32_Lo128RegClassID
: AMDGPU::VGPR_16_Lo128RegClassID);
}

bool AMDGPUOperand::isSDWAOperand(MVT type) const {
Expand Down
17 changes: 9 additions & 8 deletions llvm/lib/Target/AMDGPU/BUFInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,8 @@ class MUBUF_Load_Pseudo <string opName,
list<dag> pattern=[],
// Workaround bug bz30254
int addrKindCopy = addrKind,
RegisterClass vdata_rc = getVregSrcForVT<vdata_vt>.ret,
RegisterOperand vdata_op = getLdStVDataRegisterOperand<vdata_rc, isTFE>.ret>
RegisterOperand vdata_rc = getVregSrcForVT<vdata_vt>.ret,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This is no longer a RegisterClass, so the name vdata_rc doesn't make much sense. The field appears to be used only to compute vdata_op, so it can probably be removed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refined in bb3a515.

RegisterOperand vdata_op = getLdStVDataRegisterOperand<vdata_rc.RegClass, isTFE>.ret>
: MUBUF_Pseudo<opName,
!if(!or(isLds, isLdsOpc), (outs), (outs vdata_op:$vdata)),
!con(getMUBUFIns<addrKindCopy, [], isTFE, hasGFX12Enc>.ret,
Expand Down Expand Up @@ -601,7 +601,7 @@ class MUBUF_Store_Pseudo <string opName,
int addrKindCopy = addrKind>
: MUBUF_Pseudo<opName,
(outs),
getMUBUFIns<addrKindCopy, [getVregSrcForVT<store_vt>.ret], isTFE, hasGFX12Enc>.ret,
getMUBUFIns<addrKindCopy, [getVregSrcForVT<store_vt>.ret.RegClass], isTFE, hasGFX12Enc>.ret,
getMUBUFAsmOps<addrKindCopy, 0, 0, isTFE>.ret,
pattern>,
MUBUF_SetupAddr<addrKindCopy> {
Expand Down Expand Up @@ -1569,27 +1569,28 @@ multiclass BufferAtomicCmpSwapPat_Common<ValueType vt, ValueType data_vt, string
# !if(!eq(RtnMode, "ret"), "", "_noret")
# "_" # vt.Size);
defvar InstSuffix = !if(!eq(RtnMode, "ret"), "_RTN", "");
defvar data_vt_RC = getVregSrcForVT<data_vt>.ret.RegClass;

let AddedComplexity = !if(!eq(RtnMode, "ret"), 0, 1) in {
defvar OffsetResDag = (!cast<MUBUF_Pseudo>(Inst # "_OFFSET" # InstSuffix)
getVregSrcForVT<data_vt>.ret:$vdata_in, SReg_128:$srsrc, SCSrc_b32:$soffset,
data_vt_RC:$vdata_in, SReg_128:$srsrc, SCSrc_b32:$soffset,
offset:$offset);
def : GCNPat<
(vt (Op (MUBUFOffset v4i32:$srsrc, i32:$soffset, i32:$offset), data_vt:$vdata_in)),
!if(!eq(RtnMode, "ret"),
(EXTRACT_SUBREG (vt (COPY_TO_REGCLASS OffsetResDag, getVregSrcForVT<data_vt>.ret)),
(EXTRACT_SUBREG (vt (COPY_TO_REGCLASS OffsetResDag, data_vt_RC)),
!if(!eq(vt, i32), sub0, sub0_sub1)),
OffsetResDag)
>;

defvar Addr64ResDag = (!cast<MUBUF_Pseudo>(Inst # "_ADDR64" # InstSuffix)
getVregSrcForVT<data_vt>.ret:$vdata_in, VReg_64:$vaddr, SReg_128:$srsrc,
data_vt_RC:$vdata_in, VReg_64:$vaddr, SReg_128:$srsrc,
SCSrc_b32:$soffset, offset:$offset);
def : GCNPat<
(vt (Op (MUBUFAddr64 v4i32:$srsrc, i64:$vaddr, i32:$soffset, i32:$offset),
data_vt:$vdata_in)),
!if(!eq(RtnMode, "ret"),
(EXTRACT_SUBREG (vt (COPY_TO_REGCLASS Addr64ResDag, getVregSrcForVT<data_vt>.ret)),
(EXTRACT_SUBREG (vt (COPY_TO_REGCLASS Addr64ResDag, data_vt_RC)),
!if(!eq(vt, i32), sub0, sub0_sub1)),
Addr64ResDag)
>;
Expand Down Expand Up @@ -1820,7 +1821,7 @@ multiclass SIBufferAtomicCmpSwapPat_Common<ValueType vt, ValueType data_vt, stri
(extract_cpol_set_glc $auxiliary),
(extract_cpol $auxiliary));
defvar SrcRC = getVregSrcForVT<vt>.ret;
defvar DataRC = getVregSrcForVT<data_vt>.ret;
defvar DataRC = getVregSrcForVT<data_vt>.ret.RegClass;
defvar SubLo = !if(!eq(vt, i32), sub0, sub0_sub1);
defvar SubHi = !if(!eq(vt, i32), sub1, sub2_sub3);

Expand Down
43 changes: 42 additions & 1 deletion llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,48 @@ void AMDGPUMCCodeEmitter::getMachineOpValue(const MCInst &MI,
/// Compute the encoded value of operand \p OpNo of \p MI and store it in
/// \p Op.
///
/// Register operands are encoded directly from their hardware encoding: the
/// register index in the low bits plus a flag in bit 8 when the register is a
/// VGPR or AGPR. Every other operand is first encoded by
/// getMachineOpValueCommon(); if the operand is one of the srcN_modifiers
/// operands, op_sel bits derived from the matching srcN (and, for
/// src0_modifiers, from vdst) are then OR-ed into \p Op.
void AMDGPUMCCodeEmitter::getMachineOpValueT16(
const MCInst &MI, unsigned OpNo, APInt &Op,
SmallVectorImpl<MCFixup> &Fixups, const MCSubtargetInfo &STI) const {
// NOTE(review): this placeholder appears to be the single line the diff
// reports as deleted ("42 additions & 1 deletion"); it is presumably not meant
// to coexist with the implementation below -- confirm against the merged file.
llvm_unreachable("TODO: Implement getMachineOpValueT16().");
const MCOperand &MO = MI.getOperand(OpNo);
if (MO.isReg()) {
// Register operand: build the value from the register's encoding alone.
unsigned Enc = MRI.getEncodingValue(MO.getReg());
unsigned Idx = Enc & AMDGPU::HWEncoding::REG_IDX_MASK;
bool IsVGPR = Enc & AMDGPU::HWEncoding::IS_VGPR_OR_AGPR;
// Bit 8 carries the VGPR/AGPR flag on top of the masked register index.
Op = Idx | (IsVGPR << 8);
return;
}
// Non-register operand: start from the common encoding.
getMachineOpValueCommon(MI, MO, OpNo, Op, Fixups, STI);
// VGPRs include the suffix/op_sel bit in the register encoding, but
// immediates and SGPRs include it in src_modifiers. Therefore, copy the
// op_sel bit from the src operands into src_modifier operands if Op is
// src_modifiers and the corresponding src is a VGPR.
int SrcMOIdx = -1;
// OpNo is compared against signed operand indices below, so it must fit in
// an int.
assert(OpNo < INT_MAX);
if ((int)OpNo == AMDGPU::getNamedOperandIdx(MI.getOpcode(),
AMDGPU::OpName::src0_modifiers)) {
SrcMOIdx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0);
// Only src0_modifiers also records the destination half: set DST_OP_SEL
// when the instruction has a vdst operand and isHi() reports true for it.
int VDstMOIdx =
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::vdst);
if (VDstMOIdx != -1) {
auto DstReg = MI.getOperand(VDstMOIdx).getReg();
if (AMDGPU::isHi(DstReg, MRI))
Op |= SISrcMods::DST_OP_SEL;
}
} else if ((int)OpNo == AMDGPU::getNamedOperandIdx(
MI.getOpcode(), AMDGPU::OpName::src1_modifiers))
SrcMOIdx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src1);
else if ((int)OpNo == AMDGPU::getNamedOperandIdx(
MI.getOpcode(), AMDGPU::OpName::src2_modifiers))
SrcMOIdx = AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src2);
// OpNo is not a src modifiers operand, or the instruction has no matching
// srcN operand: nothing more to add.
if (SrcMOIdx == -1)
return;

const MCOperand &SrcMO = MI.getOperand(SrcMOIdx);
// Non-register sources contribute no op_sel bit here.
if (!SrcMO.isReg())
return;
auto SrcReg = SrcMO.getReg();
// SGPR sources are left untouched; only the remaining registers can add
// OP_SEL_0.
if (AMDGPU::isSGPR(SrcReg, &MRI))
return;
if (AMDGPU::isHi(SrcReg, MRI))
Op |= SISrcMods::OP_SEL_0;
}

void AMDGPUMCCodeEmitter::getMachineOpValueT16Lo128(
Expand Down
113 changes: 55 additions & 58 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1223,17 +1223,20 @@ def FPVRegInputModsMatchClass : AsmOperandClass {
let PredicateMethod = "isVRegWithInputMods";
}

def FPT16VRegInputModsMatchClass : AsmOperandClass {
let Name = "T16VRegWithFPInputMods";
class FPT16VRegInputModsMatchClass<bit IsFake16> : AsmOperandClass {
let Name = !if(IsFake16, "Fake16VRegWithFPInputMods",
"T16VRegWithFPInputMods");
let ParserMethod = "parseRegWithFPInputMods";
let PredicateMethod = "isT16VRegWithInputMods";
let PredicateMethod = "isT16VRegWithInputMods<" #
!if(IsFake16, "true", "false") # ">";
}

def FPVRegInputMods : InputMods <FPVRegInputModsMatchClass> {
let PrintMethod = "printOperandAndFPInputMods";
}

def FPT16VRegInputMods : InputMods <FPT16VRegInputModsMatchClass> {
class FPT16VRegInputMods<bit IsFake16>
: InputMods <FPT16VRegInputModsMatchClass<IsFake16>> {
let PrintMethod = "printOperandAndFPInputMods";
}

Expand Down Expand Up @@ -1265,13 +1268,16 @@ def IntVRegInputModsMatchClass : AsmOperandClass {
let PredicateMethod = "isVRegWithInputMods";
}

def IntT16VRegInputModsMatchClass : AsmOperandClass {
let Name = "T16VRegWithIntInputMods";
class IntT16VRegInputModsMatchClass<bit IsFake16> : AsmOperandClass {
let Name = !if(IsFake16, "Fake16VRegWithIntInputMods",
"T16VRegWithIntInputMods");
let ParserMethod = "parseRegWithIntInputMods";
let PredicateMethod = "isT16VRegWithInputMods";
let PredicateMethod = "isT16VRegWithInputMods<" #
!if(IsFake16, "true", "false") # ">";
}

def IntT16VRegInputMods : InputMods <IntT16VRegInputModsMatchClass> {
class IntT16VRegInputMods<bit IsFake16>
: InputMods <IntT16VRegInputModsMatchClass<IsFake16>> {
let PrintMethod = "printOperandAndIntInputMods";
}

Expand Down Expand Up @@ -1510,25 +1516,17 @@ class getSOPSrcForVT<ValueType VT> {
}

// Returns the vreg register operand to use for a source operand of the given VT
class getVregSrcForVT<ValueType VT> {
RegisterClass ret = !if(!eq(VT.Size, 128), VReg_128,
!if(!eq(VT.Size, 96), VReg_96,
!if(!eq(VT.Size, 64), VReg_64,
!if(!eq(VT.Size, 48), VReg_64,
VGPR_32))));
}

class getVregSrcForVT_t16<ValueType VT, bit IsFake16 = 1> {
RegisterClass ret = !if(!eq(VT.Size, 128), VReg_128,
!if(!eq(VT.Size, 96), VReg_96,
!if(!eq(VT.Size, 64), VReg_64,
!if(!eq(VT.Size, 48), VReg_64,
!if(!eq(VT.Size, 16),
!if(IsFake16, VGPR_32_Lo128, VGPR_16_Lo128),
VGPR_32)))));

RegisterOperand op = !if (!and(!eq(VT.Size, 16), !not(IsFake16)),
VGPRSrc_16_Lo128, RegisterOperand<ret>);
class getVregSrcForVT<ValueType VT, bit IsTrue16 = 0, bit IsFake16 = 0> {
RegisterOperand ret =
!if (!eq(VT.Size, 128), RegisterOperand<VReg_128>,
!if (!eq(VT.Size, 96), RegisterOperand<VReg_96>,
!if (!eq(VT.Size, 64), RegisterOperand<VReg_64>,
!if (!eq(VT.Size, 48), RegisterOperand<VReg_64>,
!if (!eq(VT.Size, 16),
!if (IsTrue16,
!if (IsFake16, VGPRSrc_32_Lo128, VGPRSrc_16_Lo128),
RegisterOperand<VGPR_32>),
RegisterOperand<VGPR_32>)))));
}

class getSDWASrcForVT <ValueType VT> {
Expand Down Expand Up @@ -1635,13 +1633,13 @@ class getSrcModDPP <ValueType VT> {
Operand ret = !if(VT.isFP, FPVRegInputMods, IntVRegInputMods);
}

class getSrcModDPP_t16 <ValueType VT> {
class getSrcModDPP_t16 <ValueType VT, bit IsFake16 = 1> {
Operand ret =
!if (VT.isFP,
!if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)),
FPT16VRegInputMods, FPVRegInputMods),
!if (!eq(VT.Value, i16.Value), IntT16VRegInputMods,
IntVRegInputMods));
FPT16VRegInputMods<IsFake16>, FPVRegInputMods),
!if (!eq(VT.Value, i16.Value),
IntT16VRegInputMods<IsFake16>, IntVRegInputMods));
}

// Return type of input modifiers operand for specified input operand for DPP
Expand Down Expand Up @@ -1784,10 +1782,9 @@ class getInsVOP3OpSel <RegisterOperand Src0RC, RegisterOperand Src1RC,
Src0Mod, Src1Mod, Src2Mod, /*HasOpSel=*/1>.ret;
}

class getInsDPPBase <RegisterOperand OldRC, RegisterClass Src0RC, RegisterClass Src1RC,
RegisterClass Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld> {

class getInsDPPBase <RegisterOperand OldRC, RegisterOperand Src0RC, RegisterOperand Src1RC,
RegisterOperand Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld> {
dag ret = !if(!eq(NumSrcArgs, 0),
// VOP1 without input operands (V_NOP)
(ins ),
Expand Down Expand Up @@ -1827,26 +1824,26 @@ class getInsDPPBase <RegisterOperand OldRC, RegisterClass Src0RC, RegisterClass
);
}

class getInsDPP <RegisterOperand OldRC, RegisterClass Src0RC, RegisterClass Src1RC,
RegisterClass Src2RC, int NumSrcArgs, bit HasModifiers,
class getInsDPP <RegisterOperand OldRC, RegisterOperand Src0RC, RegisterOperand Src1RC,
RegisterOperand Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld = 1> {
dag ret = !con(getInsDPPBase<OldRC, Src0RC, Src1RC, Src2RC, NumSrcArgs,
HasModifiers, Src0Mod, Src1Mod, Src2Mod, HasOld>.ret,
(ins dpp_ctrl:$dpp_ctrl, row_mask:$row_mask,
bank_mask:$bank_mask, bound_ctrl:$bound_ctrl));
}

class getInsDPP16 <RegisterOperand OldRC, RegisterClass Src0RC, RegisterClass Src1RC,
RegisterClass Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld = 1> {
class getInsDPP16 <RegisterOperand OldRC, RegisterOperand Src0RC, RegisterOperand Src1RC,
RegisterOperand Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld = 1> {
dag ret = !con(getInsDPP<OldRC, Src0RC, Src1RC, Src2RC, NumSrcArgs,
HasModifiers, Src0Mod, Src1Mod, Src2Mod, HasOld>.ret,
(ins FI:$fi));
}

class getInsDPP8 <RegisterOperand OldRC, RegisterClass Src0RC, RegisterClass Src1RC,
RegisterClass Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld = 1> {
class getInsDPP8 <RegisterOperand OldRC, RegisterOperand Src0RC, RegisterOperand Src1RC,
RegisterOperand Src2RC, int NumSrcArgs, bit HasModifiers,
Operand Src0Mod, Operand Src1Mod, Operand Src2Mod, bit HasOld = 1> {
dag ret = !con(getInsDPPBase<OldRC, Src0RC, Src1RC, Src2RC, NumSrcArgs,
HasModifiers, Src0Mod, Src1Mod, Src2Mod, HasOld>.ret,
(ins dpp8:$dpp8, FI:$fi));
Expand Down Expand Up @@ -2251,13 +2248,13 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
field RegisterOperand DstRCVOP3DPP = DstRC64;
field RegisterOperand DstRCSDWA = getSDWADstForVT<DstVT>.ret;
field RegisterOperand Src0RC32 = getVOPSrc0ForVT<Src0VT, IsTrue16>.ret;
field RegisterOperand Src1RC32 = RegisterOperand<getVregSrcForVT<Src1VT>.ret>;
field RegisterOperand Src1RC32 = getVregSrcForVT<Src1VT>.ret;
field RegisterOperand Src0RC64 = getVOP3SrcForVT<Src0VT>.ret;
field RegisterOperand Src1RC64 = getVOP3SrcForVT<Src1VT>.ret;
field RegisterOperand Src2RC64 = getVOP3SrcForVT<Src2VT>.ret;
field RegisterClass Src0DPP = getVregSrcForVT<Src0VT>.ret;
field RegisterClass Src1DPP = getVregSrcForVT<Src1VT>.ret;
field RegisterClass Src2DPP = getVregSrcForVT<Src2VT>.ret;
field RegisterOperand Src0DPP = getVregSrcForVT<Src0VT>.ret;
field RegisterOperand Src1DPP = getVregSrcForVT<Src1VT>.ret;
field RegisterOperand Src2DPP = getVregSrcForVT<Src2VT>.ret;
field RegisterOperand Src0VOP3DPP = VGPRSrc_32;
field RegisterOperand Src1VOP3DPP = getVOP3DPPSrcForVT<Src1VT>.ret;
field RegisterOperand Src2VOP3DPP = getVOP3DPPSrcForVT<Src2VT>.ret;
Expand Down Expand Up @@ -2443,13 +2440,13 @@ class VOPProfile_True16<VOPProfile P> : VOPProfile<P.ArgVT> {
let DstRC = getVALUDstForVT<DstVT, 1 /*IsTrue16*/, 0 /*IsVOP3Encoding*/>.ret;
let DstRC64 = getVALUDstForVT<DstVT>.ret;
let Src0RC32 = getVOPSrc0ForVT<Src0VT, 1 /*IsTrue16*/, 0 /*IsFake16*/>.ret;
let Src1RC32 = getVregSrcForVT_t16<Src1VT, 0 /*IsFake16*/>.op;
let Src0DPP = getVregSrcForVT_t16<Src0VT>.ret;
let Src1DPP = getVregSrcForVT_t16<Src1VT>.ret;
let Src2DPP = getVregSrcForVT_t16<Src2VT>.ret;
let Src0ModDPP = getSrcModDPP_t16<Src0VT>.ret;
let Src1ModDPP = getSrcModDPP_t16<Src1VT>.ret;
let Src2ModDPP = getSrcModDPP_t16<Src2VT>.ret;
let Src1RC32 = getVregSrcForVT<Src1VT, 1 /*IsTrue16*/, 0 /*IsFake16*/>.ret;
let Src0DPP = getVregSrcForVT<Src0VT, 1 /*IsTrue16*/, 0 /*IsFake16*/>.ret;
let Src1DPP = getVregSrcForVT<Src1VT, 1 /*IsTrue16*/, 0 /*IsFake16*/>.ret;
let Src2DPP = getVregSrcForVT<Src2VT, 1 /*IsTrue16*/, 0 /*IsFake16*/>.ret;
let Src0ModDPP = getSrcModDPP_t16<Src0VT, 0 /*IsFake16*/>.ret;
let Src1ModDPP = getSrcModDPP_t16<Src1VT, 0 /*IsFake16*/>.ret;
let Src2ModDPP = getSrcModDPP_t16<Src2VT, 0 /*IsFake16*/>.ret;

let DstRC64 = getVALUDstForVT<DstVT, 1 /*IsTrue16*/, 1 /*IsVOP3Encoding*/>.ret;
let Src0RC64 = getVOP3SrcForVT<Src0VT, 1 /*IsTrue16*/>.ret;
Expand All @@ -2465,10 +2462,10 @@ class VOPProfile_Fake16<VOPProfile P> : VOPProfile<P.ArgVT> {
// Most DstVT are 16-bit, but not all
let DstRC = getVALUDstForVT_fake16<DstVT>.ret;
let DstRC64 = getVALUDstForVT<DstVT>.ret;
let Src1RC32 = RegisterOperand<getVregSrcForVT_t16<Src1VT>.ret>;
let Src0DPP = getVregSrcForVT_t16<Src0VT>.ret;
let Src1DPP = getVregSrcForVT_t16<Src1VT>.ret;
let Src2DPP = getVregSrcForVT_t16<Src2VT>.ret;
let Src1RC32 = getVregSrcForVT<Src1VT, 1/*IsTrue16*/, 1/*IsFake16*/>.ret;
let Src0DPP = getVregSrcForVT<Src0VT, 1/*IsTrue16*/, 1/*IsFake16*/>.ret;
let Src1DPP = getVregSrcForVT<Src1VT, 1/*IsTrue16*/, 1/*IsFake16*/>.ret;
let Src2DPP = getVregSrcForVT<Src2VT, 1/*IsTrue16*/, 1/*IsFake16*/>.ret;
let Src0ModDPP = getSrcModDPP_t16<Src0VT>.ret;
let Src1ModDPP = getSrcModDPP_t16<Src1VT>.ret;
let Src2ModDPP = getSrcModDPP_t16<Src2VT>.ret;
Expand Down
Loading