Skip to content

[PowerPC] Add dense math bfloat16 floating-point outer-product accumulate to DMR instructions #133109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsPowerPC.td
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,22 @@ multiclass PowerPC_MMA_ACC_PP_Intrinsic<list<LLVMType> args> {
[IntrNoMem]>;
}

multiclass PowerPC_MMA_DMR_Intrinsic<list<LLVMType> args> {
def NAME: DefaultAttrsIntrinsic<[llvm_v1024i1_ty], args, [IntrNoMem]>;
def pp : DefaultAttrsIntrinsic<[llvm_v1024i1_ty],
!listconcat([llvm_v1024i1_ty], args),
[IntrNoMem]>;
def pn : DefaultAttrsIntrinsic<[llvm_v1024i1_ty],
!listconcat([llvm_v1024i1_ty], args),
[IntrNoMem]>;
def np : DefaultAttrsIntrinsic<[llvm_v1024i1_ty],
!listconcat([llvm_v1024i1_ty], args),
[IntrNoMem]>;
def nn : DefaultAttrsIntrinsic<[llvm_v1024i1_ty],
!listconcat([llvm_v1024i1_ty], args),
[IntrNoMem]>;
}

multiclass PowerPC_MMA_DMR_PP_Intrinsic<list<LLVMType> args> {
def NAME: DefaultAttrsIntrinsic<[llvm_v1024i1_ty], args, [IntrNoMem]>;
def pp : DefaultAttrsIntrinsic<[llvm_v1024i1_ty],
Expand Down Expand Up @@ -1732,6 +1748,13 @@ let TargetPrefix = "ppc" in {
[llvm_v1024i1_ty, llvm_v256i1_ty, llvm_v16i8_ty,
llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem]>;

// MMA+ Reduced-Precision: bfloat16 Outer Product Intrinsic Definitions.
defm int_ppc_mma_dmxvbf16gerx2 :
PowerPC_MMA_DMR_Intrinsic<[llvm_v256i1_ty, llvm_v16i8_ty]>;
defm int_ppc_mma_pmdmxvbf16gerx2 :
PowerPC_MMA_DMR_Intrinsic<[llvm_v256i1_ty, llvm_v16i8_ty, llvm_i32_ty,
llvm_i32_ty, llvm_i32_ty]>;
}

// XL Compat intrinsics.
Expand Down
162 changes: 161 additions & 1 deletion llvm/lib/Target/PowerPC/PPCInstrFutureMMA.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class MMIRR_XX3Form_X8YP4_XAp5B6<bits<6> opcode, bits<8> xo, dag OOL, dag IOL,
list<dag> pattern>
: PI<1, opcode, OOL, IOL, asmstr, itin> {
bits<3> AT;
bits<6> XAp;
bits<5> XAp;
bits<6> XB;
bits<8> XMSK;
bits<4> YMSK;
Expand Down Expand Up @@ -123,6 +123,40 @@ class MMIRR_XX3Form_X8YP4_XAp5B6<bits<6> opcode, bits<8> xo, dag OOL, dag IOL,
let Inst{63} = 0;
}

class MMIRR_XX3Form_X8Y4P2_XAp5B6<bits<6> opcode, bits<8> xo, dag OOL, dag IOL,
string asmstr, InstrItinClass itin,
list<dag> pattern>
: PI<1, opcode, OOL, IOL, asmstr, itin> {
bits<3> AT;
bits<5> XAp;
bits<6> XB;
bits<8> XMSK;
bits<4> YMSK;
bits<2> PMSK;

let Pattern = pattern;

// The prefix.
let Inst{6-7} = 3;
let Inst{8-11} = 9;
let Inst{12-15} = 0;
let Inst{16-17} = PMSK;
let Inst{18-19} = 0;
let Inst{20-27} = XMSK;
let Inst{28-31} = YMSK;

// The instruction.
let Inst{38-40} = AT;
let Inst{41-42} = 0;
let Inst{43-46} = XAp{3-0};
let Inst{47} = 0;
let Inst{48-52} = XB{4-0};
let Inst{53-60} = xo;
let Inst{61} = XAp{4};
let Inst{62} = XB{5};
let Inst{63} = 0;
}

multiclass DMR_UM_XOEO<bits<6> opcode, bits<8> xo, dag IOL, string asmbase,
string asmstr> {
let Predicates = [MMA, IsISAFuture] in {
Expand Down Expand Up @@ -159,6 +193,83 @@ multiclass DMR_UM_M448_XOEO<bits<6> opcode, bits<8> xo, dag IOL, string asmbase,
}
}

multiclass DMR_BF16_UM_XOEO<bits<6> opcode, bits<8> xo, dag IOL, string asmbase,
string asmstr> {
let Predicates = [MMA, IsISAFuture] in {
def NAME :
XX3Form_AT3_XAp5B6<opcode, !or(xo, 0x11), (outs dmr:$AT), IOL,
!strconcat(asmbase#" ", asmstr), IIC_VecFP, []>,
RegConstraint<"@earlyclobber $AT">;
def PP :
XX3Form_AT3_XAp5B6<opcode, xo, (outs dmr:$AT), !con((ins dmr:$ATi), IOL),
!strconcat(asmbase#"pp ", asmstr), IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}
}

multiclass DMR_UM_M284_XOEO<bits<6> opcode, bits<8> xo, dag IOL, string asmbase,
string asmstr> {
defm NAME : DMR_BF16_UM_XOEO<opcode, xo, IOL, asmbase, asmstr>;
let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
def PM#NAME :
MMIRR_XX3Form_X8Y4P2_XAp5B6<
opcode, !or(xo, 0x11), (outs dmr:$AT),
!con(IOL, (ins u8imm:$XMSK, u4imm:$YMSK, u2imm:$PMSK)),
!strconcat("pm"#asmbase#" ", asmstr#", $XMSK, $YMSK, $PMSK"),
IIC_VecFP, []>,
RegConstraint<"@earlyclobber $AT">;
def PM#NAME#PP :
MMIRR_XX3Form_X8Y4P2_XAp5B6<
opcode, xo, (outs dmr:$AT),
!con((ins dmr:$ATi), !con(IOL, (ins u8imm:$XMSK, u4imm:$YMSK, u2imm:$PMSK))),
!strconcat("pm"#asmbase#"pp ", asmstr#", $XMSK, $YMSK, $PMSK"),
IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}
}

multiclass DMR_NEG_UM_M284_XOXORf939a0<bits<6> opcode, bits<8> xo, dag IOL,
string asmbase, string asmstr> {
defm NAME : DMR_UM_M284_XOEO<opcode, xo, IOL, asmbase, asmstr>;
let Predicates = [MMA, IsISAFuture] in {
def PN : XX3Form_AT3_XAp5B6<
opcode, !xor(xo, 0xF9), (outs dmr:$AT), !con((ins dmr:$ATi), IOL),
!strconcat(asmbase#"pn ", asmstr), IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
def NP : XX3Form_AT3_XAp5B6<
opcode, !xor(xo, 0x39), (outs dmr:$AT), !con((ins dmr:$ATi), IOL),
!strconcat(asmbase#"np ", asmstr), IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
def NN : XX3Form_AT3_XAp5B6<
opcode, !xor(xo, 0xA0), (outs dmr:$AT), !con((ins dmr:$ATi), IOL),
!strconcat(asmbase#"nn ", asmstr), IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}
let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
def PM#NAME#PN :
MMIRR_XX3Form_X8Y4P2_XAp5B6<
opcode, !xor(xo, 0xF9), (outs dmr:$AT),
!con((ins dmr:$ATi), !con(IOL, (ins u8imm:$XMSK, u4imm:$YMSK, u2imm:$PMSK))),
!strconcat("pm"#asmbase#"pn ", asmstr#", $XMSK, $YMSK, $PMSK"),
IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
def PM#NAME#NP :
MMIRR_XX3Form_X8Y4P2_XAp5B6<
opcode, !xor(xo, 0x39), (outs dmr:$AT),
!con((ins dmr:$ATi), !con(IOL, (ins u8imm:$XMSK, u4imm:$YMSK, u2imm:$PMSK))),
!strconcat("pm"#asmbase#"np ", asmstr#", $XMSK, $YMSK, $PMSK"),
IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
def PM#NAME#NN :
MMIRR_XX3Form_X8Y4P2_XAp5B6<
opcode, !xor(xo, 0xA0), (outs dmr:$AT),
!con((ins dmr:$ATi), !con(IOL, (ins u8imm:$XMSK, u4imm:$YMSK, u2imm:$PMSK))),
!strconcat("pm"#asmbase#"nn ", asmstr#", $XMSK, $YMSK, $PMSK"),
IIC_VecFP, []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}
}

let Predicates = [IsISAFuture] in {
def DMXXEXTFDMR512 : XX3Form_AT3_XABp5_P1<60, 226,
(outs vsrprc:$XAp, vsrprc:$XBp),
Expand Down Expand Up @@ -231,6 +342,11 @@ let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}

// DMXVBF16GERX2, DMXVBF16GERX2PP, DMXVBF16GERX2PN, dMXVBF16GERX2NP, DMXVBF16GERX2NN
// PMDMXVBF16GERX2, PMDMXVBF16GERX2PP, PMDMXVBF16GERX2PN, PMDMXVBF16GERX2NP, PMDMXVBF16GERX2NN
defm DMXVBF16GERX2 : DMR_NEG_UM_M284_XOXORf939a0<59, 74, (ins vsrprc:$XAp, vsrc:$XB),
"dmxvbf16gerx2", "$AT, $XAp, $XB">;

// MMA+ Intrinsics
let Predicates = [MMA, IsISAFuture] in {
def : Pat<(v1024i1 (int_ppc_mma_dmxvi8gerx4 v256i1:$XAp, v16i8:$XB)),
Expand All @@ -240,6 +356,21 @@ let Predicates = [MMA, IsISAFuture] in {

def : Pat<(v1024i1 (int_ppc_mma_dmxvi8gerx4spp v1024i1:$ATi, v256i1:$XAp, v16i8:$XB)),
(DMXVI8GERX4SPP $ATi, $XAp, RCCp.BToVSRC)>;

def : Pat<(v1024i1 (int_ppc_mma_dmxvbf16gerx2 v256i1:$XAp, v16i8:$XB)),
(DMXVBF16GERX2 $XAp, RCCp.BToVSRC)>;

def : Pat<(v1024i1 (int_ppc_mma_dmxvbf16gerx2pp v1024i1:$ATi, v256i1:$XAp, v16i8:$XB)),
(DMXVBF16GERX2PP $ATi, $XAp, RCCp.BToVSRC)>;

def : Pat<(v1024i1 (int_ppc_mma_dmxvbf16gerx2pn v1024i1:$ATi, v256i1:$XAp, v16i8:$XB)),
(DMXVBF16GERX2PN $ATi, $XAp, RCCp.BToVSRC)>;

def : Pat<(v1024i1 (int_ppc_mma_dmxvbf16gerx2np v1024i1:$ATi, v256i1:$XAp, v16i8:$XB)),
(DMXVBF16GERX2NP $ATi, $XAp, RCCp.BToVSRC)>;

def : Pat<(v1024i1 (int_ppc_mma_dmxvbf16gerx2nn v1024i1:$ATi, v256i1:$XAp, v16i8:$XB)),
(DMXVBF16GERX2NN $ATi, $XAp, RCCp.BToVSRC)>;
}

let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
Expand All @@ -259,4 +390,33 @@ let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
Msk4Imm:$PMSK)),
(PMDMXVI8GERX4SPP $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk4Imm:$PMSK)>;

def : Pat<(v1024i1 (int_ppc_mma_pmdmxvbf16gerx2 v256i1:$XAp, v16i8:$XB, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)),
(PMDMXVBF16GERX2 $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;

def : Pat<(v1024i1 (int_ppc_mma_pmdmxvbf16gerx2pp v1024i1:$ATi, v256i1:$XAp, v16i8:$XB,
Msk8Imm:$XMSK, Msk4Imm:$YMSK,
Msk2Imm:$PMSK)),
(PMDMXVBF16GERX2PP $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;

def : Pat<(v1024i1 (int_ppc_mma_pmdmxvbf16gerx2pn v1024i1:$ATi, v256i1:$XAp, v16i8:$XB,
Msk8Imm:$XMSK, Msk4Imm:$YMSK,
Msk2Imm:$PMSK)),
(PMDMXVBF16GERX2PN $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;

def : Pat<(v1024i1 (int_ppc_mma_pmdmxvbf16gerx2np v1024i1:$ATi, v256i1:$XAp, v16i8:$XB,
Msk8Imm:$XMSK, Msk4Imm:$YMSK,
Msk2Imm:$PMSK)),
(PMDMXVBF16GERX2NP $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;

def : Pat<(v1024i1 (int_ppc_mma_pmdmxvbf16gerx2nn v1024i1:$ATi, v256i1:$XAp, v16i8:$XB,
Msk8Imm:$XMSK, Msk4Imm:$YMSK,
Msk2Imm:$PMSK)),
(PMDMXVBF16GERX2NN $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;
}
Loading
Loading