Skip to content

[PowerPC] Add load/store support for v2048i1 and DMF cryptography instructions #136145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def llvm_v128i1_ty : LLVMType<v128i1>; // 128 x i1
def llvm_v256i1_ty : LLVMType<v256i1>; // 256 x i1
def llvm_v512i1_ty : LLVMType<v512i1>; // 512 x i1
def llvm_v1024i1_ty : LLVMType<v1024i1>; //1024 x i1
def llvm_v2048i1_ty : LLVMType<v2048i1>; //2048 x i1
def llvm_v4096i1_ty : LLVMType<v4096i1>; //4096 x i1

def llvm_v1i8_ty : LLVMType<v1i8>; // 1 x i8
Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsPowerPC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1778,6 +1778,20 @@ let TargetPrefix = "ppc" in {
defm int_ppc_mma_pmdmxvf16gerx2 :
PowerPC_MMA_DMR_Intrinsic<[llvm_v256i1_ty, llvm_v16i8_ty, llvm_i32_ty,
llvm_i32_ty, llvm_i32_ty]>;
def int_ppc_mma_dmsha2hash :
DefaultAttrsIntrinsic<[llvm_v1024i1_ty], [llvm_v1024i1_ty,
llvm_v1024i1_ty, llvm_i32_ty],
[IntrNoMem, ImmArg<ArgIndex<2>>]>;

def int_ppc_mma_dmsha3hash :
DefaultAttrsIntrinsic<[llvm_v2048i1_ty], [llvm_v2048i1_ty,
llvm_i32_ty], [IntrNoMem, ImmArg<ArgIndex<1>>]>;

def int_ppc_mma_dmxxshapad :
DefaultAttrsIntrinsic<[llvm_v1024i1_ty], [llvm_v1024i1_ty,
llvm_v16i8_ty, llvm_i32_ty, llvm_i32_ty,
llvm_i32_ty], [IntrNoMem, ImmArg<ArgIndex<2>>,
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>]>;
}

// XL Compat intrinsics.
Expand Down
144 changes: 115 additions & 29 deletions llvm/lib/Target/PowerPC/PPCISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1361,8 +1361,11 @@ PPCTargetLowering::PPCTargetLowering(const PPCTargetMachine &TM,
if (Subtarget.isISAFuture()) {
addRegisterClass(MVT::v512i1, &PPC::WACCRCRegClass);
addRegisterClass(MVT::v1024i1, &PPC::DMRRCRegClass);
addRegisterClass(MVT::v2048i1, &PPC::DMRpRCRegClass);
setOperationAction(ISD::LOAD, MVT::v1024i1, Custom);
setOperationAction(ISD::STORE, MVT::v1024i1, Custom);
setOperationAction(ISD::LOAD, MVT::v2048i1, Custom);
setOperationAction(ISD::STORE, MVT::v2048i1, Custom);
} else {
addRegisterClass(MVT::v512i1, &PPC::UACCRCRegClass);
}
Expand Down Expand Up @@ -11890,15 +11893,19 @@ SDValue PPCTargetLowering::LowerDMFVectorLoad(SDValue Op,
SDValue LoadChain = LN->getChain();
SDValue BasePtr = LN->getBasePtr();
EVT VT = Op.getValueType();
bool IsV1024i1 = VT == MVT::v1024i1;
bool IsV2048i1 = VT == MVT::v2048i1;

// Type v1024i1 is used for Dense Math dmr registers.
assert(VT == MVT::v1024i1 && "Unsupported type.");
// The types v1024i1 and v2048i1 are used for Dense Math dmr registers and
// Dense Math dmr pair registers, respectively.
assert((IsV1024i1 || IsV2048i1) && "Unsupported type.");
assert((Subtarget.hasMMA() && Subtarget.isISAFuture()) &&
"Dense Math support required.");
assert(Subtarget.pairedVectorMemops() && "Vector pair support required.");

SmallVector<SDValue, 4> Loads;
SmallVector<SDValue, 4> LoadChains;
SmallVector<SDValue, 8> Loads;
SmallVector<SDValue, 8> LoadChains;

SDValue IntrinID = DAG.getConstant(Intrinsic::ppc_vsx_lxvp, dl, MVT::i32);
SDValue LoadOps[] = {LoadChain, IntrinID, BasePtr};
MachineMemOperand *MMO = LN->getMemOperand();
Expand Down Expand Up @@ -11934,11 +11941,36 @@ SDValue PPCTargetLowering::LowerDMFVectorLoad(SDValue Op,
SDValue HiSub = DAG.getTargetConstant(PPC::sub_wacc_hi, dl, MVT::i32);
SDValue RC = DAG.getTargetConstant(PPC::DMRRCRegClassID, dl, MVT::i32);
const SDValue Ops[] = {RC, Lo, LoSub, Hi, HiSub};

SDValue Value =
SDValue(DAG.getMachineNode(PPC::REG_SEQUENCE, dl, MVT::v1024i1, Ops), 0);

SDValue RetOps[] = {Value, TF};
return DAG.getMergeValues(RetOps, dl);
if (IsV1024i1) {
return DAG.getMergeValues({Value, TF}, dl);
}

// Handle Loads for V2048i1 which represents a dmr pair.
SDValue DmrPValue;
SDValue Dmr1Lo(DAG.getMachineNode(PPC::DMXXINSTDMR512, dl, MVT::v512i1,
Loads[4], Loads[5]),
0);
SDValue Dmr1Hi(DAG.getMachineNode(PPC::DMXXINSTDMR512_HI, dl, MVT::v512i1,
Loads[6], Loads[7]),
0);
const SDValue Dmr1Ops[] = {RC, Dmr1Lo, LoSub, Dmr1Hi, HiSub};
SDValue Dmr1Value = SDValue(
DAG.getMachineNode(PPC::REG_SEQUENCE, dl, MVT::v1024i1, Dmr1Ops), 0);

SDValue Dmr0Sub = DAG.getTargetConstant(PPC::sub_dmr0, dl, MVT::i32);
SDValue Dmr1Sub = DAG.getTargetConstant(PPC::sub_dmr1, dl, MVT::i32);

SDValue DmrPRC = DAG.getTargetConstant(PPC::DMRpRCRegClassID, dl, MVT::i32);
const SDValue DmrPOps[] = {DmrPRC, Value, Dmr0Sub, Dmr1Value, Dmr1Sub};

DmrPValue = SDValue(
DAG.getMachineNode(PPC::REG_SEQUENCE, dl, MVT::v2048i1, DmrPOps), 0);

return DAG.getMergeValues({DmrPValue, TF}, dl);
}

SDValue PPCTargetLowering::LowerVectorLoad(SDValue Op,
Expand All @@ -11949,7 +11981,7 @@ SDValue PPCTargetLowering::LowerVectorLoad(SDValue Op,
SDValue BasePtr = LN->getBasePtr();
EVT VT = Op.getValueType();

if (VT == MVT::v1024i1)
if (VT == MVT::v1024i1 || VT == MVT::v2048i1)
return LowerDMFVectorLoad(Op, DAG);

if (VT != MVT::v256i1 && VT != MVT::v512i1)
Expand Down Expand Up @@ -11996,34 +12028,88 @@ SDValue PPCTargetLowering::LowerDMFVectorStore(SDValue Op,
StoreSDNode *SN = cast<StoreSDNode>(Op.getNode());
SDValue StoreChain = SN->getChain();
SDValue BasePtr = SN->getBasePtr();
SmallVector<SDValue, 4> Values;
SmallVector<SDValue, 4> Stores;
SmallVector<SDValue, 8> Values;
SmallVector<SDValue, 8> Stores;
EVT VT = SN->getValue().getValueType();
bool IsV1024i1 = VT == MVT::v1024i1;
bool IsV2048i1 = VT == MVT::v2048i1;

// Type v1024i1 is used for Dense Math dmr registers.
assert(VT == MVT::v1024i1 && "Unsupported type.");
// The types v1024i1 and v2048i1 are used for Dense Math dmr registers and
// Dense Math dmr pair registers, respectively.
assert((IsV1024i1 || IsV2048i1) && "Unsupported type.");
assert((Subtarget.hasMMA() && Subtarget.isISAFuture()) &&
"Dense Math support required.");
assert(Subtarget.pairedVectorMemops() && "Vector pair support required.");

SDValue Lo(
DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_wacc_lo, dl, MVT::i32)),
0);
SDValue Hi(
DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_wacc_hi, dl, MVT::i32)),
0);
EVT ReturnTypes[] = {MVT::v256i1, MVT::v256i1};
MachineSDNode *ExtNode =
DAG.getMachineNode(PPC::DMXXEXTFDMR512, dl, ReturnTypes, Lo);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
ExtNode = DAG.getMachineNode(PPC::DMXXEXTFDMR512_HI, dl, ReturnTypes, Hi);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
if (IsV1024i1) {
SDValue Lo(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_wacc_lo, dl, MVT::i32)),
0);
SDValue Hi(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_wacc_hi, dl, MVT::i32)),
0);
MachineSDNode *ExtNode =
DAG.getMachineNode(PPC::DMXXEXTFDMR512, dl, ReturnTypes, Lo);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
ExtNode = DAG.getMachineNode(PPC::DMXXEXTFDMR512_HI, dl, ReturnTypes, Hi);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
} else {
// This corresponds to v2048i1 which represents a dmr pair.
SDValue Dmr0(
DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, dl, MVT::v1024i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_dmr0, dl, MVT::i32)),
0);

SDValue Dmr1(
DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, dl, MVT::v1024i1,
Op.getOperand(1),
DAG.getTargetConstant(PPC::sub_dmr1, dl, MVT::i32)),
0);

SDValue Dmr0Lo(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1, Dmr0,
DAG.getTargetConstant(PPC::sub_wacc_lo, dl, MVT::i32)),
0);

SDValue Dmr0Hi(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1, Dmr0,
DAG.getTargetConstant(PPC::sub_wacc_hi, dl, MVT::i32)),
0);

SDValue Dmr1Lo(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1, Dmr1,
DAG.getTargetConstant(PPC::sub_wacc_lo, dl, MVT::i32)),
0);

SDValue Dmr1Hi(DAG.getMachineNode(
TargetOpcode::EXTRACT_SUBREG, dl, MVT::v512i1, Dmr1,
DAG.getTargetConstant(PPC::sub_wacc_hi, dl, MVT::i32)),
0);

MachineSDNode *ExtNode =
DAG.getMachineNode(PPC::DMXXEXTFDMR512, dl, ReturnTypes, Dmr0Lo);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
ExtNode =
DAG.getMachineNode(PPC::DMXXEXTFDMR512_HI, dl, ReturnTypes, Dmr0Hi);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
ExtNode = DAG.getMachineNode(PPC::DMXXEXTFDMR512, dl, ReturnTypes, Dmr1Lo);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
ExtNode =
DAG.getMachineNode(PPC::DMXXEXTFDMR512_HI, dl, ReturnTypes, Dmr1Hi);
Values.push_back(SDValue(ExtNode, 0));
Values.push_back(SDValue(ExtNode, 1));
}

if (Subtarget.isLittleEndian())
std::reverse(Values.begin(), Values.end());
Expand Down Expand Up @@ -12062,7 +12148,7 @@ SDValue PPCTargetLowering::LowerVectorStore(SDValue Op,
SDValue Value2 = SN->getValue();
EVT StoreVT = Value.getValueType();

if (StoreVT == MVT::v1024i1)
if (StoreVT == MVT::v1024i1 || StoreVT == MVT::v2048i1)
return LowerDMFVectorStore(Op, DAG);

if (StoreVT != MVT::v256i1 && StoreVT != MVT::v512i1)
Expand Down
123 changes: 123 additions & 0 deletions llvm/lib/Target/PowerPC/PPCInstrFutureMMA.td
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,63 @@ multiclass DMR_NEG_UM_M284_XOXORd11188<bits<6> opcode, bits<8> xo, dag IOL,
}
}

class XForm_AT3_T1_AB3<bits<6> opcode, bits<5> o, bits<10> xo, dag OOL, dag IOL,
string asmstr, list<dag> pattern>
: I <opcode, OOL, IOL, asmstr, NoItinerary> {
bits<3> AT;
bits<3> AB;
bits<1> T;

let Pattern = pattern;

let Inst{6-8} = AT{2-0};
let Inst{9} = 0;
let Inst{10} = T;
let Inst{11-15} = o;
let Inst{16-18} = AB{2-0};
let Inst{19-20} = 0;
let Inst{21-30} = xo;
let Inst{31} = 0;
}

class XForm_ATp2_SR5<bits<6> opcode, bits<5> o, bits<10> xo, dag OOL, dag IOL,
string asmstr, list<dag> pattern>
: I <opcode, OOL, IOL, asmstr, NoItinerary> {
bits<2> ATp;
bits<5> SR;

let Pattern = pattern;

let Inst{6-7} = ATp{1-0};
let Inst{8-10} = 0;
let Inst{11-15} = o;
let Inst{16-20} = SR{4-0};
let Inst{21-30} = xo;
let Inst{31} = 0;
}

class XX2Form_AT3_XB6_ID2_E1_BL2<bits<6> opcode, bits<9> xo, dag OOL, dag IOL,
string asmstr, list<dag> pattern>
: I<opcode, OOL, IOL, asmstr, NoItinerary> {
bits<3> AT;
bits<6> XB;
bits<2> ID;
bits<1> E;
bits<2> BL;

let Pattern = pattern;

let Inst{6-8} = AT{2-0};
let Inst{9-10} = 0;
let Inst{11-12} = ID{1-0};
let Inst{13} = E;
let Inst{14-15} = BL{1-0};
let Inst{16-20} = XB{4-0};
let Inst{21-29} = xo;
let Inst{30} = XB{5};
let Inst{31} = 0;
}

let Predicates = [IsISAFuture] in {
def DMXXEXTFDMR512 : XX3Form_AT3_XABp5_P1<60, 226,
(outs vsrprc:$XAp, vsrprc:$XBp),
Expand Down Expand Up @@ -415,6 +472,27 @@ defm DMXVBF16GERX2 : DMR_NEG_UM_M284_XOXORf939a0<59, 74, (ins vsrprc:$XAp, vsrc:
defm DMXVF16GERX2 : DMR_NEG_UM_M284_XOXORd11188<59, 66, (ins vsrprc:$XAp, vsrc:$XB),
"dmxvf16gerx2", "$AT, $XAp, $XB">;

// DMF cryptography [support] Instructions
let Predicates = [IsISAFuture] in {
def DMSHA2HASH :
XForm_AT3_T1_AB3<31, 14, 177, (outs dmr:$AT), (ins dmr:$ATi, dmr:$AB, u1imm:$T),
"dmsha2hash $AT, $AB, $T",
[(set v1024i1:$AT, (int_ppc_mma_dmsha2hash v1024i1:$ATi, v1024i1:$AB, timm:$T))]>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;

def DMSHA3HASH :
XForm_ATp2_SR5<31, 15, 177, (outs dmrprc:$ATp), (ins dmrprc:$ATpi , u5imm:$SR),
"dmsha3hash $ATp, $SR",
[(set v2048i1:$ATp, (int_ppc_mma_dmsha3hash v2048i1:$ATpi, timm:$SR))]>,
RegConstraint<"$ATpi = $ATp">, NoEncode<"$ATpi">;

def DMXXSHAPAD :
XX2Form_AT3_XB6_ID2_E1_BL2<60, 421, (outs dmr:$AT),
(ins dmr:$ATi, vsrc:$XB, u2imm:$ID, u1imm:$E, u2imm:$BL),
"dmxxshapad $AT, $XB, $ID, $E, $BL", []>,
RegConstraint<"$ATi = $AT">, NoEncode<"$ATi">;
}

// MMA+ Intrinsics
let Predicates = [MMA, IsISAFuture] in {
def : Pat<(v1024i1 (int_ppc_mma_dmxvi8gerx4 v256i1:$XAp, v16i8:$XB)),
Expand Down Expand Up @@ -532,3 +610,48 @@ let Predicates = [MMA, PrefixInstrs, IsISAFuture] in {
(PMDMXVF16GERX2NN $ATi, $XAp, RCCp.BToVSRC, Msk8Imm:$XMSK,
Msk4Imm:$YMSK, Msk2Imm:$PMSK)>;
}

// Cryptography Intrinsic
let Predicates = [IsISAFuture] in {
def : Pat<(v1024i1 (int_ppc_mma_dmxxshapad v1024i1:$ATi, v16i8:$XB, timm:$ID,
timm:$E, timm:$BL)), (DMXXSHAPAD $ATi, RCCp.BToVSRC, $ID, $E, $BL)>;
}

// MMA+ Instruction aliases
let Predicates = [IsISAFuture] in {
def : InstAlias<"dmsha256hash $AT, $AB",
(DMSHA2HASH dmr:$AT, dmr:$AB, 0)>;

def : InstAlias<"dmsha512hash $AT, $AB",
(DMSHA2HASH dmr:$AT, dmr:$AB, 1)>;

def : InstAlias<"dmsha3dw $ATp",
(DMSHA3HASH dmrprc:$ATp, 0)>;

def : InstAlias<"dmcryshash $ATp",
(DMSHA3HASH dmrprc:$ATp, 12)>;

def : InstAlias<"dmxxsha3512pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 0, u1imm:$E, 0)>;

def : InstAlias<"dmxxsha3384pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 0, u1imm:$E, 1)>;

def : InstAlias<"dmxxsha3256pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 0, u1imm:$E, 2)>;

def : InstAlias<"dmxxsha3224pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 0, u1imm:$E, 3)>;

def : InstAlias<"dmxxshake256pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 1, u1imm:$E, 0)>;

def : InstAlias<"dmxxshake128pad $AT, $XB, $E",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 1, u1imm:$E, 1)>;

def : InstAlias<"dmxxsha384512pad $AT, $XB",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 2, 0, 0)>;

def : InstAlias<"dmxxsha224256pad $AT, $XB",
(DMXXSHAPAD dmr:$AT, vsrc:$XB, 3, 0, 0)>;
}
2 changes: 1 addition & 1 deletion llvm/lib/Target/PowerPC/PPCRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,6 @@ def PPCRegDMRpRCAsmOperand : AsmOperandClass {
let PredicateMethod = "isDMRpRegNumber";
}

def dmrp : RegisterOperand<DMRpRC> {
def dmrprc : RegisterOperand<DMRpRC> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we renaming this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is dmr pair reg class, similar to

def vsrc : RegisterOperand {
let ParserMatchClass = PPCRegVSRCAsmOperand;
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to change all the dmr register classes then? Cause it seems the acc and dmr register def no longer follow that since I see acc, wacc, dmr, dmrrow etc... it's strange to just change the name for dmr pair.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will update this in a separate patch to either use *rc for all DMF related register classes or rename the dmrprc back to dmrp.

let ParserMatchClass = PPCRegDMRpRCAsmOperand;
}
Loading