Skip to content

AMDGPU/GFX12: Add new dot4 fp8/bf8 instructions #77892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/BuiltinsAMDGPU.def
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ TARGET_BUILTIN(__builtin_amdgcn_sudot4, "iIbiIbiiIb", "nc", "dot8-insts")
TARGET_BUILTIN(__builtin_amdgcn_sdot8, "SiSiSiSiIb", "nc", "dot1-insts")
TARGET_BUILTIN(__builtin_amdgcn_udot8, "UiUiUiUiIb", "nc", "dot7-insts")
TARGET_BUILTIN(__builtin_amdgcn_sudot8, "iIbiIbiiIb", "nc", "dot8-insts")
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_fp8_bf8, "fUiUif", "nc", "gfx12-insts")
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_bf8_fp8, "fUiUif", "nc", "gfx12-insts")
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_fp8_fp8, "fUiUif", "nc", "gfx12-insts")
TARGET_BUILTIN(__builtin_amdgcn_dot4_f32_bf8_bf8, "fUiUif", "nc", "gfx12-insts")

//===----------------------------------------------------------------------===//
// GFX10+ only builtins.
Expand Down
5 changes: 5 additions & 0 deletions clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-err.cl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,9 @@ kernel void builtins_amdgcn_dl_insts_err(

iOut[3] = __builtin_amdgcn_sudot8(false, A, true, B, C, false); // expected-error {{'__builtin_amdgcn_sudot8' needs target feature dot8-insts}}
iOut[4] = __builtin_amdgcn_sudot8(true, A, false, B, C, true); // expected-error {{'__builtin_amdgcn_sudot8' needs target feature dot8-insts}}

fOut[5] = __builtin_amdgcn_dot4_f32_fp8_bf8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_fp8_bf8' needs target feature gfx12-insts}}
fOut[6] = __builtin_amdgcn_dot4_f32_bf8_fp8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_bf8_fp8' needs target feature gfx12-insts}}
fOut[7] = __builtin_amdgcn_dot4_f32_fp8_fp8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_fp8_fp8' needs target feature gfx12-insts}}
fOut[8] = __builtin_amdgcn_dot4_f32_bf8_bf8(uiA, uiB, fC); // expected-error {{'__builtin_amdgcn_dot4_f32_bf8_bf8' needs target feature gfx12-insts}}
}
20 changes: 20 additions & 0 deletions clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx12.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// REQUIRES: amdgpu-registered-target

// RUN: %clang_cc1 -triple amdgcn-unknown-unknown -target-cpu gfx1200 -S -emit-llvm -o - %s | FileCheck %s

typedef unsigned int uint;

// CHECK-LABEL: @builtins_amdgcn_dl_insts
// CHECK: call float @llvm.amdgcn.dot4.f32.fp8.bf8(i32 %uiA, i32 %uiB, float %fC)
// CHECK: call float @llvm.amdgcn.dot4.f32.bf8.fp8(i32 %uiA, i32 %uiB, float %fC)
// CHECK: call float @llvm.amdgcn.dot4.f32.fp8.fp8(i32 %uiA, i32 %uiB, float %fC)
// CHECK: call float @llvm.amdgcn.dot4.f32.bf8.bf8(i32 %uiA, i32 %uiB, float %fC)

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void builtins_amdgcn_dl_insts_err(global float *fOut,
uint uiA, uint uiB, float fC) {
fOut[0] = __builtin_amdgcn_dot4_f32_fp8_bf8(uiA, uiB, fC);
fOut[1] = __builtin_amdgcn_dot4_f32_bf8_fp8(uiA, uiB, fC);
fOut[2] = __builtin_amdgcn_dot4_f32_fp8_fp8(uiA, uiB, fC);
fOut[3] = __builtin_amdgcn_dot4_f32_bf8_bf8(uiA, uiB, fC);
}
19 changes: 19 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -2727,6 +2727,25 @@ def int_amdgcn_udot8 :
ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<5>>]
>;

// f32 %r = llvm.amdgcn.dot4.f32.type_a.type_b (v4type_a (as i32) %a, v4type_b (as i32) %b, f32 %c)
// %r = %a[0] * %b[0] + %a[1] * %b[1] + %a[2] * %b[2] + %a[3] * %b[3] + %c
class AMDGPU8bitFloatDot4Intrinsic :
ClangBuiltin<!subst("int", "__builtin", NAME)>,
DefaultAttrsIntrinsic<
[llvm_float_ty], // %r
[
llvm_i32_ty, // %a
llvm_i32_ty, // %b
llvm_float_ty, // %c
],
[IntrNoMem, IntrSpeculatable]
>;

def int_amdgcn_dot4_f32_fp8_bf8 : AMDGPU8bitFloatDot4Intrinsic;
def int_amdgcn_dot4_f32_bf8_fp8 : AMDGPU8bitFloatDot4Intrinsic;
def int_amdgcn_dot4_f32_fp8_fp8 : AMDGPU8bitFloatDot4Intrinsic;
def int_amdgcn_dot4_f32_bf8_bf8 : AMDGPU8bitFloatDot4Intrinsic;

//===----------------------------------------------------------------------===//
// gfx908 intrinsics
// ===----------------------------------------------------------------------===//
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4491,6 +4491,10 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case Intrinsic::amdgcn_fdot2_f32_bf16:
case Intrinsic::amdgcn_sudot4:
case Intrinsic::amdgcn_sudot8:
case Intrinsic::amdgcn_dot4_f32_fp8_bf8:
case Intrinsic::amdgcn_dot4_f32_bf8_fp8:
case Intrinsic::amdgcn_dot4_f32_fp8_fp8:
case Intrinsic::amdgcn_dot4_f32_bf8_bf8:
case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16:
case Intrinsic::amdgcn_wmma_f16_16x16x16_f16:
case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied:
Expand Down
46 changes: 46 additions & 0 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
bool validateMIMGD16(const MCInst &Inst);
bool validateMIMGMSAA(const MCInst &Inst);
bool validateOpSel(const MCInst &Inst);
bool validateNeg(const MCInst &Inst, int OpName);
bool validateDPP(const MCInst &Inst, const OperandVector &Operands);
bool validateVccOperand(unsigned Reg) const;
bool validateVOPLiteral(const MCInst &Inst, const OperandVector &Operands);
Expand Down Expand Up @@ -4357,6 +4358,41 @@ bool AMDGPUAsmParser::validateOpSel(const MCInst &Inst) {
return true;
}

bool AMDGPUAsmParser::validateNeg(const MCInst &Inst, int OpName) {
assert(OpName == AMDGPU::OpName::neg_lo || OpName == AMDGPU::OpName::neg_hi);

const unsigned Opc = Inst.getOpcode();
uint64_t TSFlags = MII.get(Opc).TSFlags;

// v_dot4 fp8/bf8 neg_lo/neg_hi not allowed on src0 and src1 (allowed on src2)
if (!(TSFlags & SIInstrFlags::IsDOT))
return true;

int NegIdx = AMDGPU::getNamedOperandIdx(Opc, OpName);
if (NegIdx == -1)
return true;

unsigned Neg = Inst.getOperand(NegIdx).getImm();

// Instructions that have neg_lo or neg_hi operand but neg modifier is allowed
// on some src operands but not allowed on other.
// It is convenient that such instructions don't have src_modifiers operand
// for src operands that don't allow neg because they also don't allow opsel.

int SrcMods[3] = {AMDGPU::OpName::src0_modifiers,
AMDGPU::OpName::src1_modifiers,
AMDGPU::OpName::src2_modifiers};

for (unsigned i = 0; i < 3; ++i) {
if (!AMDGPU::hasNamedOperand(Opc, SrcMods[i])) {
if (Neg & (1 << i))
return false;
}
}

return true;
}

bool AMDGPUAsmParser::validateDPP(const MCInst &Inst,
const OperandVector &Operands) {
const unsigned Opc = Inst.getOpcode();
Expand Down Expand Up @@ -4834,6 +4870,16 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
"invalid op_sel operand");
return false;
}
if (!validateNeg(Inst, AMDGPU::OpName::neg_lo)) {
Error(getImmLoc(AMDGPUOperand::ImmTyNegLo, Operands),
"invalid neg_lo operand");
return false;
}
if (!validateNeg(Inst, AMDGPU::OpName::neg_hi)) {
Error(getImmLoc(AMDGPUOperand::ImmTyNegHi, Operands),
"invalid neg_hi operand");
return false;
}
if (!validateDPP(Inst, Operands)) {
return false;
}
Expand Down
17 changes: 11 additions & 6 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1260,14 +1260,19 @@ void AMDGPUInstPrinter::printPackedModifier(const MCInst *MI,
int NumOps = 0;
int Ops[3];

for (int OpName : { AMDGPU::OpName::src0_modifiers,
AMDGPU::OpName::src1_modifiers,
AMDGPU::OpName::src2_modifiers }) {
int Idx = AMDGPU::getNamedOperandIdx(Opc, OpName);
if (Idx == -1)
std::pair<int, int> MOps[] = {
{AMDGPU::OpName::src0_modifiers, AMDGPU::OpName::src0},
{AMDGPU::OpName::src1_modifiers, AMDGPU::OpName::src1},
{AMDGPU::OpName::src2_modifiers, AMDGPU::OpName::src2}};
int DefaultValue = (Mod == SISrcMods::OP_SEL_1);

for (auto [SrcMod, Src] : MOps) {
if (!AMDGPU::hasNamedOperand(Opc, Src))
break;

Ops[NumOps++] = MI->getOperand(Idx).getImm();
int ModIdx = AMDGPU::getNamedOperandIdx(Opc, SrcMod);
Ops[NumOps++] =
(ModIdx != -1) ? MI->getOperand(ModIdx).getImm() : DefaultValue;
}

const bool HasDstSel =
Expand Down
47 changes: 47 additions & 0 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,48 @@ def : GCNPat < (int_amdgcn_sdot4 i32:$src0,
>;
} // End SubtargetPredicate = HasDot8Insts

// Does not use opsel, no src_modifiers on src0 and src1.
// src_modifiers on src2(f32) are f32 fneg(neg_lo[2]) and f32 fabs(neg_hi[2]).
def VOP3P_DOTF8_Profile : VOP3P_Profile<VOPProfile <[f32, i32, i32, f32]>,
VOP3_PACKED, 1> {
let HasClamp = 0;
let HasOpSel = 0;
let HasOMod = 0;
let IsDOT = 1;
let HasSrc0Mods = 0;
let HasSrc1Mods = 0;
let HasSrc2Mods = 1;

let InsVOP3P = (ins VSrc_b32:$src0, VSrc_b32:$src1,
PackedF16InputMods:$src2_modifiers, VSrc_f32:$src2,
neg_lo0:$neg_lo, neg_hi0:$neg_hi);

let InsVOP3DPP8 = (ins DstRC:$old, VGPR_32:$src0, VRegSrc_32:$src1,
PackedF16InputMods:$src2_modifiers, VRegSrc_32:$src2,
neg_lo0:$neg_lo, neg_hi0:$neg_hi, dpp8:$dpp8, FI:$fi);

let InsVOP3DPP16 = (ins DstRC:$old, VGPR_32:$src0, VRegSrc_32:$src1,
PackedF16InputMods:$src2_modifiers, VRegSrc_32:$src2,
neg_lo0:$neg_lo, neg_hi0:$neg_hi, dpp_ctrl:$dpp_ctrl,
row_mask:$row_mask, bank_mask:$bank_mask,
bound_ctrl:$bound_ctrl, FI:$fi);
}

multiclass VOP3PDOTF8Inst <string OpName, SDPatternOperator intrinsic_node> {
defm NAME : VOP3PInst<OpName, VOP3P_DOTF8_Profile, null_frag, 1>;

let SubtargetPredicate = isGFX12Plus in
def : GCNPat <(intrinsic_node i32:$src0, i32:$src1,
(VOP3Mods f32:$src2, i32:$src2_modifiers)),
(!cast<Instruction>(NAME) i32:$src0, i32:$src1,
i32:$src2_modifiers, f32:$src2)>;
}

defm V_DOT4_F32_FP8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_fp8_bf8", int_amdgcn_dot4_f32_fp8_bf8>;
defm V_DOT4_F32_BF8_FP8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_fp8", int_amdgcn_dot4_f32_bf8_fp8>;
defm V_DOT4_F32_FP8_FP8 : VOP3PDOTF8Inst<"v_dot4_f32_fp8_fp8", int_amdgcn_dot4_f32_fp8_fp8>;
defm V_DOT4_F32_BF8_BF8 : VOP3PDOTF8Inst<"v_dot4_f32_bf8_bf8", int_amdgcn_dot4_f32_bf8_bf8>;

def : UDot2Pat<V_DOT2_U32_U16>;
def : SDot2Pat<V_DOT2_I32_I16>;

Expand Down Expand Up @@ -1019,6 +1061,11 @@ defm V_PK_MAX_NUM_F16 : VOP3P_Real_with_name_gfx12<0x1c, "V_PK_MAX_F16", "v_pk_m
defm V_PK_MINIMUM_F16 : VOP3P_Real_gfx12<0x1d>;
defm V_PK_MAXIMUM_F16 : VOP3P_Real_gfx12<0x1e>;

defm V_DOT4_F32_FP8_BF8 : VOP3P_Realtriple<GFX12Gen, 0x24>;
defm V_DOT4_F32_BF8_FP8 : VOP3P_Realtriple<GFX12Gen, 0x25>;
defm V_DOT4_F32_FP8_FP8 : VOP3P_Realtriple<GFX12Gen, 0x26>;
defm V_DOT4_F32_BF8_BF8 : VOP3P_Realtriple<GFX12Gen, 0x27>;

//===----------------------------------------------------------------------===//
// GFX11
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 7 additions & 6 deletions llvm/lib/Target/AMDGPU/VOPInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class VOP3_Pseudo <string opName, VOPProfile P, list<dag> pattern = [],
class VOP3P_Pseudo <string opName, VOPProfile P, list<dag> pattern = []> :
VOP3_Pseudo<opName, P, pattern, 1> {
let VOP3P = 1;
let IsDOT = P.IsDOT;
}

class VOP_Real<VOP_Pseudo ps> {
Expand Down Expand Up @@ -387,7 +388,7 @@ class VOP3Pe <bits<7> op, VOPProfile P> : Enc64 {
let Inst{12} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{2}, 0); // op_sel(1)
let Inst{13} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{2}, 0); // op_sel(2)

let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, ?); // op_sel_hi(2)
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(2)

let Inst{15} = !if(P.HasClamp, clamp{0}, 0);

Expand All @@ -396,8 +397,8 @@ class VOP3Pe <bits<7> op, VOPProfile P> : Enc64 {
let Inst{40-32} = !if(P.HasSrc0, src0, 0);
let Inst{49-41} = !if(P.HasSrc1, src1, 0);
let Inst{58-50} = !if(P.HasSrc2, src2, 0);
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, ?); // op_sel_hi(0)
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, ?); // op_sel_hi(1)
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(0)
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(1)
let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, 0); // neg (lo)
let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, 0); // neg (lo)
let Inst{63} = !if(P.HasSrc2Mods, src2_modifiers{0}, 0); // neg (lo)
Expand Down Expand Up @@ -772,12 +773,12 @@ class VOP3P_DPPe_Common_Base<bits<7> op, VOPProfile P> : Enc96 {
let Inst{11} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{2}, 0); // op_sel(0)
let Inst{12} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{2}, 0); // op_sel(1)
let Inst{13} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{2}, 0); // op_sel(2)
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, ?); // op_sel_hi(2)
let Inst{14} = !if(!and(P.HasSrc2, P.HasOpSel), src2_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(2)
let Inst{15} = !if(P.HasClamp, clamp{0}, 0);
let Inst{22-16} = op;
let Inst{31-23} = 0x198; // encoding
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, ?); // op_sel_hi(0)
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, ?); // op_sel_hi(1)
let Inst{59} = !if(!and(P.HasSrc0, P.HasOpSel), src0_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(0)
let Inst{60} = !if(!and(P.HasSrc1, P.HasOpSel), src1_modifiers{3}, !if(P.IsDOT, 1, ?)); // op_sel_hi(1)
let Inst{61} = !if(P.HasSrc0Mods, src0_modifiers{0}, 0); // neg (lo)
let Inst{62} = !if(P.HasSrc1Mods, src1_modifiers{0}, 0); // neg (lo)
let Inst{63} = !if(P.HasSrc2Mods, src2_modifiers{0}, 0); // neg (lo)
Expand Down
Loading