Skip to content

Commit ba423c0

Browse files
Reuse-EmitZAInstr-to-add-Za-Matrix
1 parent 6d6de45 commit ba423c0

File tree

4 files changed

+111
-92
lines changed

4 files changed

+111
-92
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
395395
template <unsigned MaxIdx, unsigned Scale>
396396
void SelectMultiVectorMove(SDNode *N, unsigned NumVecs, unsigned BaseReg,
397397
unsigned Op);
398-
template <unsigned MaxIdx, unsigned Scale>
399-
void SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs, unsigned Op);
398+
void SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs, unsigned Op,
399+
unsigned MaxIdx, unsigned Scale);
400400
bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm);
401401
/// SVE Reg+Imm addressing mode.
402402
template <int64_t Min, int64_t Max>
@@ -2004,9 +2004,9 @@ void AArch64DAGToDAGISel::SelectMultiVectorMove(SDNode *N, unsigned NumVecs,
20042004
CurDAG->RemoveDeadNode(N);
20052005
}
20062006

2007-
template <unsigned MaxIdx, unsigned Scale>
20082007
void AArch64DAGToDAGISel::SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs,
2009-
unsigned Op) {
2008+
unsigned Op, unsigned MaxIdx,
2009+
unsigned Scale) {
20102010

20112011
SDValue SliceBase = N->getOperand(3);
20122012
SDValue Base, Offset;
@@ -5276,68 +5276,68 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
52765276
}
52775277
case Intrinsic::aarch64_sme_readz_horiz_x2: {
52785278
if (VT == MVT::nxv16i8) {
5279-
SelectMultiVectorMoveZ<14, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_B_PSEUDO);
5279+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_B_PSEUDO, 14, 2);
52805280
return;
52815281
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
52825282
VT == MVT::nxv8bf16) {
5283-
SelectMultiVectorMoveZ<6, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_H_PSEUDO);
5283+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_H_PSEUDO, 6, 2);
52845284
return;
52855285
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5286-
SelectMultiVectorMoveZ<2, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_S_PSEUDO);
5286+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_S_PSEUDO, 2, 2);
52875287
return;
52885288
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5289-
SelectMultiVectorMoveZ<0, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_D_PSEUDO);
5289+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_D_PSEUDO, 0, 2);
52905290
return;
52915291
}
52925292
break;
52935293
}
52945294
case Intrinsic::aarch64_sme_readz_vert_x2: {
52955295
if (VT == MVT::nxv16i8) {
5296-
SelectMultiVectorMoveZ<14, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_B_PSEUDO);
5296+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_B_PSEUDO, 14, 2);
52975297
return;
52985298
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
52995299
VT == MVT::nxv8bf16) {
5300-
SelectMultiVectorMoveZ<6, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_H_PSEUDO);
5300+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_H_PSEUDO, 6, 2);
53015301
return;
53025302
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5303-
SelectMultiVectorMoveZ<2, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_S_PSEUDO);
5303+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_S_PSEUDO, 2, 2);
53045304
return;
53055305
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5306-
SelectMultiVectorMoveZ<0, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_D_PSEUDO);
5306+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_D_PSEUDO, 0, 2);
53075307
return;
53085308
}
53095309
break;
53105310
}
53115311
case Intrinsic::aarch64_sme_readz_horiz_x4: {
53125312
if (VT == MVT::nxv16i8) {
5313-
SelectMultiVectorMoveZ<12, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_B_PSEUDO);
5313+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_B_PSEUDO, 12, 4);
53145314
return;
53155315
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
53165316
VT == MVT::nxv8bf16) {
5317-
SelectMultiVectorMoveZ<4, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_H_PSEUDO);
5317+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_H_PSEUDO, 4, 4);
53185318
return;
53195319
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5320-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_S_PSEUDO);
5320+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_S_PSEUDO, 0, 4);
53215321
return;
53225322
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5323-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_D_PSEUDO);
5323+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_D_PSEUDO, 0, 4);
53245324
return;
53255325
}
53265326
break;
53275327
}
53285328
case Intrinsic::aarch64_sme_readz_vert_x4: {
53295329
if (VT == MVT::nxv16i8) {
5330-
SelectMultiVectorMoveZ<12, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_B_PSEUDO);
5330+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_B_PSEUDO, 12, 4);
53315331
return;
53325332
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
53335333
VT == MVT::nxv8bf16) {
5334-
SelectMultiVectorMoveZ<4, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_H_PSEUDO);
5334+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_H_PSEUDO, 4, 4);
53355335
return;
53365336
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5337-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_S_PSEUDO);
5337+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_S_PSEUDO, 0, 4);
53385338
return;
53395339
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5340-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_D_PSEUDO);
5340+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_D_PSEUDO, 0, 4);
53415341
return;
53425342
}
53435343
break;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,19 +2992,28 @@ MachineBasicBlock *AArch64TargetLowering::EmitZTInstr(MachineInstr &MI,
29922992

29932993
MachineBasicBlock *
29942994
AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
2995-
MachineInstr &MI,
2996-
MachineBasicBlock *BB, bool HasTile) const {
2995+
MachineInstr &MI, MachineBasicBlock *BB,
2996+
bool HasTile, bool HasZPROut) const {
29972997
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
29982998
MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
29992999
unsigned StartIdx = 0;
30003000

3001-
if (HasTile) {
3002-
MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
3003-
MIB.addReg(BaseReg + MI.getOperand(0).getImm());
3004-
StartIdx = 1;
3005-
} else
3006-
MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
3007-
3001+
if (HasZPROut) {
3002+
if (HasTile) {
3003+
MIB.add(MI.getOperand(0)); // Output ZPR
3004+
MIB.addReg(BaseReg + MI.getOperand(1).getImm(),
3005+
RegState::Define); // Output ZA Tile
3006+
MIB.addReg(BaseReg + MI.getOperand(1).getImm()); // Input ZA Tile
3007+
StartIdx = 2;
3008+
}
3009+
} else {
3010+
if (HasTile) {
3011+
MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
3012+
MIB.addReg(BaseReg + MI.getOperand(0).getImm());
3013+
StartIdx = 1;
3014+
} else
3015+
MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
3016+
}
30083017
for (unsigned I = StartIdx; I < MI.getNumOperands(); ++I)
30093018
MIB.add(MI.getOperand(I));
30103019

@@ -3113,17 +3122,59 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
31133122
TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
31143123
switch (SMEMatrixType) {
31153124
case (AArch64::SMEMatrixArray):
3116-
return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB, /*HasTile*/ false);
3125+
return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB, /*HasTile*/ false,
3126+
/*HasZPROut*/ false);
31173127
case (AArch64::SMEMatrixTileB):
3118-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB, /*HasTile*/ true);
3128+
switch (MI.getOpcode()) {
3129+
case AArch64::MOVAZ_2ZMI_H_B_PSEUDO:
3130+
case AArch64::MOVAZ_2ZMI_V_B_PSEUDO:
3131+
case AArch64::MOVAZ_4ZMI_H_B_PSEUDO:
3132+
case AArch64::MOVAZ_4ZMI_V_B_PSEUDO:
3133+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB,
3134+
/*HasTile*/ true, /*HasZPROut*/ true);
3135+
default:
3136+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB,
3137+
/*HasTile*/ true, /*HasZPROut*/ false);
3138+
}
31193139
case (AArch64::SMEMatrixTileH):
3120-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB, /*HasTile*/ true);
3140+
switch (MI.getOpcode()) {
3141+
case AArch64::MOVAZ_2ZMI_H_H_PSEUDO:
3142+
case AArch64::MOVAZ_2ZMI_V_H_PSEUDO:
3143+
case AArch64::MOVAZ_4ZMI_H_H_PSEUDO:
3144+
case AArch64::MOVAZ_4ZMI_V_H_PSEUDO:
3145+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB,
3146+
/*HasTile*/ true, /*HasZPROut*/ true);
3147+
default:
3148+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB,
3149+
/*HasTile*/ true, /*HasZPROut*/ false);
3150+
}
31213151
case (AArch64::SMEMatrixTileS):
3122-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB, /*HasTile*/ true);
3152+
switch (MI.getOpcode()) {
3153+
case AArch64::MOVAZ_2ZMI_H_S_PSEUDO:
3154+
case AArch64::MOVAZ_2ZMI_V_S_PSEUDO:
3155+
case AArch64::MOVAZ_4ZMI_H_S_PSEUDO:
3156+
case AArch64::MOVAZ_4ZMI_V_S_PSEUDO:
3157+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB,
3158+
/*HasTile*/ true, /*HasZPROut*/ true);
3159+
default:
3160+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB,
3161+
/*HasTile*/ true, /*HasZPROut*/ false);
3162+
}
31233163
case (AArch64::SMEMatrixTileD):
3124-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB, /*HasTile*/ true);
3164+
switch (MI.getOpcode()) {
3165+
case AArch64::MOVAZ_2ZMI_H_D_PSEUDO:
3166+
case AArch64::MOVAZ_2ZMI_V_D_PSEUDO:
3167+
case AArch64::MOVAZ_4ZMI_H_D_PSEUDO:
3168+
case AArch64::MOVAZ_4ZMI_V_D_PSEUDO:
3169+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB,
3170+
/*HasTile*/ true, /*HasZPROut*/ true);
3171+
default:
3172+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB,
3173+
/*HasTile*/ true, /*HasZPROut*/ false);
3174+
}
31253175
case (AArch64::SMEMatrixTileQ):
3126-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB, /*HasTile*/ true);
3176+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB, /*HasTile*/ true,
3177+
/*HasZPROut*/ false);
31273178
}
31283179
}
31293180

@@ -3195,38 +3246,6 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
31953246
return EmitZero(MI, BB);
31963247
case AArch64::ZERO_T_PSEUDO:
31973248
return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
3198-
case AArch64::MOVAZ_2ZMI_H_B_PSEUDO:
3199-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_B, AArch64::ZAB0, MI, BB);
3200-
case AArch64::MOVAZ_2ZMI_H_H_PSEUDO:
3201-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_H, AArch64::ZAH0, MI, BB);
3202-
case AArch64::MOVAZ_2ZMI_H_S_PSEUDO:
3203-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_S, AArch64::ZAS0, MI, BB);
3204-
case AArch64::MOVAZ_2ZMI_H_D_PSEUDO:
3205-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_D, AArch64::ZAD0, MI, BB);
3206-
case AArch64::MOVAZ_2ZMI_V_B_PSEUDO:
3207-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_B, AArch64::ZAB0, MI, BB);
3208-
case AArch64::MOVAZ_2ZMI_V_H_PSEUDO:
3209-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_H, AArch64::ZAH0, MI, BB);
3210-
case AArch64::MOVAZ_2ZMI_V_S_PSEUDO:
3211-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_S, AArch64::ZAS0, MI, BB);
3212-
case AArch64::MOVAZ_2ZMI_V_D_PSEUDO:
3213-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_D, AArch64::ZAD0, MI, BB);
3214-
case AArch64::MOVAZ_4ZMI_H_B_PSEUDO:
3215-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_B, AArch64::ZAB0, MI, BB);
3216-
case AArch64::MOVAZ_4ZMI_H_H_PSEUDO:
3217-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_H, AArch64::ZAH0, MI, BB);
3218-
case AArch64::MOVAZ_4ZMI_H_S_PSEUDO:
3219-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_S, AArch64::ZAS0, MI, BB);
3220-
case AArch64::MOVAZ_4ZMI_H_D_PSEUDO:
3221-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_D, AArch64::ZAD0, MI, BB);
3222-
case AArch64::MOVAZ_4ZMI_V_B_PSEUDO:
3223-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_B, AArch64::ZAB0, MI, BB);
3224-
case AArch64::MOVAZ_4ZMI_V_H_PSEUDO:
3225-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_H, AArch64::ZAH0, MI, BB);
3226-
case AArch64::MOVAZ_4ZMI_V_S_PSEUDO:
3227-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_S, AArch64::ZAS0, MI, BB);
3228-
case AArch64::MOVAZ_4ZMI_V_D_PSEUDO:
3229-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_D, AArch64::ZAD0, MI, BB);
32303249
}
32313250
}
32323251

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ class AArch64TargetLowering : public TargetLowering {
659659
MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
660660
MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg,
661661
MachineInstr &MI, MachineBasicBlock *BB,
662-
bool HasTile) const;
662+
bool HasTile, bool HasZPROut) const;
663663
MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB,
664664
unsigned Opcode, bool Op0IsDef) const;
665665
MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;

0 commit comments

Comments
 (0)