Skip to content

Commit f0e5e45

Browse files
Reuse-EmitZAInstr-to-add-Za-Matrix
1 parent 3622a1d commit f0e5e45

File tree

4 files changed

+111
-92
lines changed

4 files changed

+111
-92
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
395395
template <unsigned MaxIdx, unsigned Scale>
396396
void SelectMultiVectorMove(SDNode *N, unsigned NumVecs, unsigned BaseReg,
397397
unsigned Op);
398-
template <unsigned MaxIdx, unsigned Scale>
399-
void SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs, unsigned Op);
398+
void SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs, unsigned Op,
399+
unsigned MaxIdx, unsigned Scale);
400400
bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm);
401401
/// SVE Reg+Imm addressing mode.
402402
template <int64_t Min, int64_t Max>
@@ -2004,9 +2004,9 @@ void AArch64DAGToDAGISel::SelectMultiVectorMove(SDNode *N, unsigned NumVecs,
20042004
CurDAG->RemoveDeadNode(N);
20052005
}
20062006

2007-
template <unsigned MaxIdx, unsigned Scale>
20082007
void AArch64DAGToDAGISel::SelectMultiVectorMoveZ(SDNode *N, unsigned NumVecs,
2009-
unsigned Op) {
2008+
unsigned Op, unsigned MaxIdx,
2009+
unsigned Scale) {
20102010

20112011
SDValue SliceBase = N->getOperand(3);
20122012
SDValue Base, Offset;
@@ -5274,68 +5274,68 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
52745274
}
52755275
case Intrinsic::aarch64_sme_readz_horiz_x2: {
52765276
if (VT == MVT::nxv16i8) {
5277-
SelectMultiVectorMoveZ<14, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_B_PSEUDO);
5277+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_B_PSEUDO, 14, 2);
52785278
return;
52795279
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
52805280
VT == MVT::nxv8bf16) {
5281-
SelectMultiVectorMoveZ<6, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_H_PSEUDO);
5281+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_H_PSEUDO, 6, 2);
52825282
return;
52835283
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5284-
SelectMultiVectorMoveZ<2, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_S_PSEUDO);
5284+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_S_PSEUDO, 2, 2);
52855285
return;
52865286
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5287-
SelectMultiVectorMoveZ<0, 2>(Node, 2, AArch64::MOVAZ_2ZMI_H_D_PSEUDO);
5287+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_H_D_PSEUDO, 0, 2);
52885288
return;
52895289
}
52905290
break;
52915291
}
52925292
case Intrinsic::aarch64_sme_readz_vert_x2: {
52935293
if (VT == MVT::nxv16i8) {
5294-
SelectMultiVectorMoveZ<14, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_B_PSEUDO);
5294+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_B_PSEUDO, 14, 2);
52955295
return;
52965296
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
52975297
VT == MVT::nxv8bf16) {
5298-
SelectMultiVectorMoveZ<6, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_H_PSEUDO);
5298+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_H_PSEUDO, 6, 2);
52995299
return;
53005300
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5301-
SelectMultiVectorMoveZ<2, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_S_PSEUDO);
5301+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_S_PSEUDO, 2, 2);
53025302
return;
53035303
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5304-
SelectMultiVectorMoveZ<0, 2>(Node, 2, AArch64::MOVAZ_2ZMI_V_D_PSEUDO);
5304+
SelectMultiVectorMoveZ(Node, 2, AArch64::MOVAZ_2ZMI_V_D_PSEUDO, 0, 2);
53055305
return;
53065306
}
53075307
break;
53085308
}
53095309
case Intrinsic::aarch64_sme_readz_horiz_x4: {
53105310
if (VT == MVT::nxv16i8) {
5311-
SelectMultiVectorMoveZ<12, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_B_PSEUDO);
5311+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_B_PSEUDO, 12, 4);
53125312
return;
53135313
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
53145314
VT == MVT::nxv8bf16) {
5315-
SelectMultiVectorMoveZ<4, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_H_PSEUDO);
5315+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_H_PSEUDO, 4, 4);
53165316
return;
53175317
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5318-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_S_PSEUDO);
5318+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_S_PSEUDO, 0, 4);
53195319
return;
53205320
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5321-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_H_D_PSEUDO);
5321+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_H_D_PSEUDO, 0, 4);
53225322
return;
53235323
}
53245324
break;
53255325
}
53265326
case Intrinsic::aarch64_sme_readz_vert_x4: {
53275327
if (VT == MVT::nxv16i8) {
5328-
SelectMultiVectorMoveZ<12, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_B_PSEUDO);
5328+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_B_PSEUDO, 12, 4);
53295329
return;
53305330
} else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 ||
53315331
VT == MVT::nxv8bf16) {
5332-
SelectMultiVectorMoveZ<4, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_H_PSEUDO);
5332+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_H_PSEUDO, 4, 4);
53335333
return;
53345334
} else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) {
5335-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_S_PSEUDO);
5335+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_S_PSEUDO, 0, 4);
53365336
return;
53375337
} else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) {
5338-
SelectMultiVectorMoveZ<0, 4>(Node, 4, AArch64::MOVAZ_4ZMI_V_D_PSEUDO);
5338+
SelectMultiVectorMoveZ(Node, 4, AArch64::MOVAZ_4ZMI_V_D_PSEUDO, 0, 4);
53395339
return;
53405340
}
53415341
break;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2965,19 +2965,28 @@ MachineBasicBlock *AArch64TargetLowering::EmitZTInstr(MachineInstr &MI,
29652965

29662966
MachineBasicBlock *
29672967
AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
2968-
MachineInstr &MI,
2969-
MachineBasicBlock *BB, bool HasTile) const {
2968+
MachineInstr &MI, MachineBasicBlock *BB,
2969+
bool HasTile, bool HasZPROut) const {
29702970
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
29712971
MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
29722972
unsigned StartIdx = 0;
29732973

2974-
if (HasTile) {
2975-
MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
2976-
MIB.addReg(BaseReg + MI.getOperand(0).getImm());
2977-
StartIdx = 1;
2978-
} else
2979-
MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
2980-
2974+
if (HasZPROut) {
2975+
if (HasTile) {
2976+
MIB.add(MI.getOperand(0)); // Output ZPR
2977+
MIB.addReg(BaseReg + MI.getOperand(1).getImm(),
2978+
RegState::Define); // Output ZA Tile
2979+
MIB.addReg(BaseReg + MI.getOperand(1).getImm()); // Input Za Tile
2980+
StartIdx = 2;
2981+
}
2982+
} else {
2983+
if (HasTile) {
2984+
MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
2985+
MIB.addReg(BaseReg + MI.getOperand(0).getImm());
2986+
StartIdx = 1;
2987+
} else
2988+
MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
2989+
}
29812990
for (unsigned I = StartIdx; I < MI.getNumOperands(); ++I)
29822991
MIB.add(MI.getOperand(I));
29832992

@@ -3012,17 +3021,59 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
30123021
TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
30133022
switch (SMEMatrixType) {
30143023
case (AArch64::SMEMatrixArray):
3015-
return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB, /*HasTile*/ false);
3024+
return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB, /*HasTile*/ false,
3025+
/*HasZPROut*/ false);
30163026
case (AArch64::SMEMatrixTileB):
3017-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB, /*HasTile*/ true);
3027+
switch (MI.getOpcode()) {
3028+
case AArch64::MOVAZ_2ZMI_H_B_PSEUDO:
3029+
case AArch64::MOVAZ_2ZMI_V_B_PSEUDO:
3030+
case AArch64::MOVAZ_4ZMI_H_B_PSEUDO:
3031+
case AArch64::MOVAZ_4ZMI_V_B_PSEUDO:
3032+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB,
3033+
/*HasTile*/ true, /*HasZPROut*/ true);
3034+
default:
3035+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB,
3036+
/*HasTile*/ true, /*HasZPROut*/ false);
3037+
}
30183038
case (AArch64::SMEMatrixTileH):
3019-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB, /*HasTile*/ true);
3039+
switch (MI.getOpcode()) {
3040+
case AArch64::MOVAZ_2ZMI_H_H_PSEUDO:
3041+
case AArch64::MOVAZ_2ZMI_V_H_PSEUDO:
3042+
case AArch64::MOVAZ_4ZMI_H_H_PSEUDO:
3043+
case AArch64::MOVAZ_4ZMI_V_H_PSEUDO:
3044+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB,
3045+
/*HasTile*/ true, /*HasZPROut*/ true);
3046+
default:
3047+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB,
3048+
/*HasTile*/ true, /*HasZPROut*/ false);
3049+
}
30203050
case (AArch64::SMEMatrixTileS):
3021-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB, /*HasTile*/ true);
3051+
switch (MI.getOpcode()) {
3052+
case AArch64::MOVAZ_2ZMI_H_S_PSEUDO:
3053+
case AArch64::MOVAZ_2ZMI_V_S_PSEUDO:
3054+
case AArch64::MOVAZ_4ZMI_H_S_PSEUDO:
3055+
case AArch64::MOVAZ_4ZMI_V_S_PSEUDO:
3056+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB,
3057+
/*HasTile*/ true, /*HasZPROut*/ true);
3058+
default:
3059+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB,
3060+
/*HasTile*/ true, /*HasZPROut*/ false);
3061+
}
30223062
case (AArch64::SMEMatrixTileD):
3023-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB, /*HasTile*/ true);
3063+
switch (MI.getOpcode()) {
3064+
case AArch64::MOVAZ_2ZMI_H_D_PSEUDO:
3065+
case AArch64::MOVAZ_2ZMI_V_D_PSEUDO:
3066+
case AArch64::MOVAZ_4ZMI_H_D_PSEUDO:
3067+
case AArch64::MOVAZ_4ZMI_V_D_PSEUDO:
3068+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB,
3069+
/*HasTile*/ true, /*HasZPROut*/ true);
3070+
default:
3071+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB,
3072+
/*HasTile*/ true, /*HasZPROut*/ false);
3073+
}
30243074
case (AArch64::SMEMatrixTileQ):
3025-
return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB, /*HasTile*/ true);
3075+
return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB, /*HasTile*/ true,
3076+
/*HasZPROut*/ false);
30263077
}
30273078
}
30283079

@@ -3091,38 +3142,6 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
30913142
return EmitZero(MI, BB);
30923143
case AArch64::ZERO_T_PSEUDO:
30933144
return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
3094-
case AArch64::MOVAZ_2ZMI_H_B_PSEUDO:
3095-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_B, AArch64::ZAB0, MI, BB);
3096-
case AArch64::MOVAZ_2ZMI_H_H_PSEUDO:
3097-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_H, AArch64::ZAH0, MI, BB);
3098-
case AArch64::MOVAZ_2ZMI_H_S_PSEUDO:
3099-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_S, AArch64::ZAS0, MI, BB);
3100-
case AArch64::MOVAZ_2ZMI_H_D_PSEUDO:
3101-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_H_D, AArch64::ZAD0, MI, BB);
3102-
case AArch64::MOVAZ_2ZMI_V_B_PSEUDO:
3103-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_B, AArch64::ZAB0, MI, BB);
3104-
case AArch64::MOVAZ_2ZMI_V_H_PSEUDO:
3105-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_H, AArch64::ZAH0, MI, BB);
3106-
case AArch64::MOVAZ_2ZMI_V_S_PSEUDO:
3107-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_S, AArch64::ZAS0, MI, BB);
3108-
case AArch64::MOVAZ_2ZMI_V_D_PSEUDO:
3109-
return EmitTileMovaz(AArch64::MOVAZ_2ZMI_V_D, AArch64::ZAD0, MI, BB);
3110-
case AArch64::MOVAZ_4ZMI_H_B_PSEUDO:
3111-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_B, AArch64::ZAB0, MI, BB);
3112-
case AArch64::MOVAZ_4ZMI_H_H_PSEUDO:
3113-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_H, AArch64::ZAH0, MI, BB);
3114-
case AArch64::MOVAZ_4ZMI_H_S_PSEUDO:
3115-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_S, AArch64::ZAS0, MI, BB);
3116-
case AArch64::MOVAZ_4ZMI_H_D_PSEUDO:
3117-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_H_D, AArch64::ZAD0, MI, BB);
3118-
case AArch64::MOVAZ_4ZMI_V_B_PSEUDO:
3119-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_B, AArch64::ZAB0, MI, BB);
3120-
case AArch64::MOVAZ_4ZMI_V_H_PSEUDO:
3121-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_H, AArch64::ZAH0, MI, BB);
3122-
case AArch64::MOVAZ_4ZMI_V_S_PSEUDO:
3123-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_S, AArch64::ZAS0, MI, BB);
3124-
case AArch64::MOVAZ_4ZMI_V_D_PSEUDO:
3125-
return EmitTileMovaz(AArch64::MOVAZ_4ZMI_V_D, AArch64::ZAD0, MI, BB);
31263145
}
31273146
}
31283147

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ class AArch64TargetLowering : public TargetLowering {
654654
MachineBasicBlock *EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const;
655655
MachineBasicBlock *EmitZAInstr(unsigned Opc, unsigned BaseReg,
656656
MachineInstr &MI, MachineBasicBlock *BB,
657-
bool HasTile) const;
657+
bool HasTile, bool HasZPROut) const;
658658
MachineBasicBlock *EmitZTInstr(MachineInstr &MI, MachineBasicBlock *BB,
659659
unsigned Opcode, bool Op0IsDef) const;
660660
MachineBasicBlock *EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const;

0 commit comments

Comments
 (0)