Skip to content

Commit 2e27c4e

Browse files
committed
[AArch64][SME] Add zero instruction
This patch adds the zero instruction for zeroing a list of 64-bit element ZA tiles. The instruction takes a list of up to eight tiles ZA0.D-ZA7.D, which must be in order, e.g. zero {za0.d,za1.d,za2.d,za3.d,za4.d,za5.d,za6.d,za7.d} zero {za1.d,za3.d,za5.d,za7.d} The assembler also accepts 32-bit, 16-bit and 8-bit element tiles which are mapped to corresponding 64-bit element tiles in accordance with the architecturally defined mapping between different element size tiles, e.g. * Zeroing ZA0.B, or the entire array name ZA, is equivalent to zeroing all eight 64-bit element tiles ZA0.D to ZA7.D. * Zeroing ZA0.S is equivalent to zeroing ZA0.D and ZA4.D. The preferred disassembly of this instruction uses the shortest list of tile names that represent the encoded immediate mask, e.g. * An immediate which encodes 64-bit element tiles ZA0.D, ZA1.D, ZA4.D and ZA5.D is disassembled as {ZA0.S, ZA1.S}. * An immediate which encodes 64-bit element tiles ZA0.D, ZA2.D, ZA4.D and ZA6.D is disassembled as {ZA0.H}. * An all-ones immediate is disassembled as {ZA}. * An all-zeros immediate is disassembled as an empty list {}. This patch adds the MatrixTileList asm operand and related parsing to support this. Depends on D105570. The reference can be found here: https://developer.arm.com/documentation/ddi0602/2021-06 Reviewed By: david-arm Differential Revision: https://reviews.llvm.org/D105575
1 parent 80e0266 commit 2e27c4e

File tree

10 files changed

+650
-0
lines changed

10 files changed

+650
-0
lines changed

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,22 @@ class MatrixOperand<RegisterClass RC, int EltSize> : RegisterOperand<RC> {
13521352

13531353
def MatrixOp : MatrixOperand<MPR, 0>;
13541354

1355+
class MatrixTileListAsmOperand : AsmOperandClass {
1356+
let Name = "MatrixTileList";
1357+
let ParserMethod = "tryParseMatrixTileList";
1358+
let RenderMethod = "addMatrixTileListOperands";
1359+
let PredicateMethod = "isMatrixTileList";
1360+
}
1361+
1362+
class MatrixTileListOperand : Operand<i8> {
1363+
let ParserMatchClass = MatrixTileListAsmOperand<>;
1364+
let DecoderMethod = "DecodeMatrixTileListRegisterClass";
1365+
let EncoderMethod = "EncodeMatrixTileListRegisterClass";
1366+
let PrintMethod = "printMatrixTileList";
1367+
}
1368+
1369+
def MatrixTileList : MatrixTileListOperand<>;
1370+
13551371
def MatrixIndexGPR32_12_15 : RegisterClass<"AArch64", [i32], 32, (sequence "W%u", 12, 15)> {
13561372
let DiagnosticType = "InvalidMatrixIndexGPR32_12_15";
13571373
}

llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ defm STR_ZA : sme_spill<"str">;
8888
defm INSERT_MXIPZ : sme_vector_to_tile<"mova">;
8989
defm EXTRACT_ZPMXI : sme_tile_to_vector<"mova">;
9090

91+
//===----------------------------------------------------------------------===//
92+
// Zero instruction
93+
//===----------------------------------------------------------------------===//
94+
95+
defm ZERO_M : sme_zero<"zero">;
96+
9197
//===----------------------------------------------------------------------===//
9298
// Mode selection and state access instructions
9399
//===----------------------------------------------------------------------===//

llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/ADT/APInt.h"
1919
#include "llvm/ADT/ArrayRef.h"
2020
#include "llvm/ADT/STLExtras.h"
21+
#include "llvm/ADT/SmallSet.h"
2122
#include "llvm/ADT/SmallVector.h"
2223
#include "llvm/ADT/StringExtras.h"
2324
#include "llvm/ADT/StringMap.h"
@@ -262,6 +263,7 @@ class AArch64AsmParser : public MCTargetAsmParser {
262263
template <RegKind VectorKind>
263264
OperandMatchResultTy tryParseVectorList(OperandVector &Operands,
264265
bool ExpectMatch = false);
266+
OperandMatchResultTy tryParseMatrixTileList(OperandVector &Operands);
265267
OperandMatchResultTy tryParseSVEPattern(OperandVector &Operands);
266268
OperandMatchResultTy tryParseGPR64x8(OperandVector &Operands);
267269

@@ -322,6 +324,7 @@ class AArch64Operand : public MCParsedAsmOperand {
322324
k_CondCode,
323325
k_Register,
324326
k_MatrixRegister,
327+
k_MatrixTileList,
325328
k_SVCR,
326329
k_VectorList,
327330
k_VectorIndex,
@@ -383,6 +386,10 @@ class AArch64Operand : public MCParsedAsmOperand {
383386
MatrixKind Kind;
384387
};
385388

389+
struct MatrixTileListOp {
390+
unsigned RegMask = 0;
391+
};
392+
386393
struct VectorListOp {
387394
unsigned RegNum;
388395
unsigned Count;
@@ -460,6 +467,7 @@ class AArch64Operand : public MCParsedAsmOperand {
460467
struct TokOp Tok;
461468
struct RegOp Reg;
462469
struct MatrixRegOp MatrixReg;
470+
struct MatrixTileListOp MatrixTileList;
463471
struct VectorListOp VectorList;
464472
struct VectorIndexOp VectorIndex;
465473
struct ImmOp Imm;
@@ -512,6 +520,9 @@ class AArch64Operand : public MCParsedAsmOperand {
512520
case k_MatrixRegister:
513521
MatrixReg = o.MatrixReg;
514522
break;
523+
case k_MatrixTileList:
524+
MatrixTileList = o.MatrixTileList;
525+
break;
515526
case k_VectorList:
516527
VectorList = o.VectorList;
517528
break;
@@ -622,6 +633,11 @@ class AArch64Operand : public MCParsedAsmOperand {
622633
return MatrixReg.Kind;
623634
}
624635

636+
unsigned getMatrixTileListRegMask() const {
637+
assert(isMatrixTileList() && "Invalid access!");
638+
return MatrixTileList.RegMask;
639+
}
640+
625641
RegConstraintEqualityTy getRegEqualityTy() const {
626642
assert(Kind == k_Register && "Invalid access!");
627643
return Reg.EqualityTy;
@@ -1143,6 +1159,7 @@ class AArch64Operand : public MCParsedAsmOperand {
11431159
}
11441160

11451161
bool isMatrix() const { return Kind == k_MatrixRegister; }
1162+
bool isMatrixTileList() const { return Kind == k_MatrixTileList; }
11461163

11471164
template <unsigned Class> bool isSVEVectorReg() const {
11481165
RegKind RK;
@@ -1643,6 +1660,13 @@ class AArch64Operand : public MCParsedAsmOperand {
16431660
FirstRegs[(unsigned)RegTy][0]));
16441661
}
16451662

1663+
void addMatrixTileListOperands(MCInst &Inst, unsigned N) const {
1664+
assert(N == 1 && "Invalid number of operands!");
1665+
unsigned RegMask = getMatrixTileListRegMask();
1666+
assert(RegMask <= 0xFF && "Invalid mask!");
1667+
Inst.addOperand(MCOperand::createImm(RegMask));
1668+
}
1669+
16461670
void addVectorIndexOperands(MCInst &Inst, unsigned N) const {
16471671
assert(N == 1 && "Invalid number of operands!");
16481672
Inst.addOperand(MCOperand::createImm(getVectorIndex()));
@@ -2012,6 +2036,45 @@ class AArch64Operand : public MCParsedAsmOperand {
20122036
return Op;
20132037
}
20142038

2039+
static std::unique_ptr<AArch64Operand>
2040+
CreateMatrixTileList(unsigned RegMask, SMLoc S, SMLoc E, MCContext &Ctx) {
2041+
auto Op = std::make_unique<AArch64Operand>(k_MatrixTileList, Ctx);
2042+
Op->MatrixTileList.RegMask = RegMask;
2043+
Op->StartLoc = S;
2044+
Op->EndLoc = E;
2045+
return Op;
2046+
}
2047+
2048+
static void ComputeRegsForAlias(unsigned Reg, SmallSet<unsigned, 8> &OutRegs,
2049+
const unsigned ElementWidth) {
2050+
static std::map<std::pair<unsigned, unsigned>, std::vector<unsigned>>
2051+
RegMap = {
2052+
{{0, AArch64::ZAB0},
2053+
{AArch64::ZAD0, AArch64::ZAD1, AArch64::ZAD2, AArch64::ZAD3,
2054+
AArch64::ZAD4, AArch64::ZAD5, AArch64::ZAD6, AArch64::ZAD7}},
2055+
{{8, AArch64::ZAB0},
2056+
{AArch64::ZAD0, AArch64::ZAD1, AArch64::ZAD2, AArch64::ZAD3,
2057+
AArch64::ZAD4, AArch64::ZAD5, AArch64::ZAD6, AArch64::ZAD7}},
2058+
{{16, AArch64::ZAH0},
2059+
{AArch64::ZAD0, AArch64::ZAD2, AArch64::ZAD4, AArch64::ZAD6}},
2060+
{{16, AArch64::ZAH1},
2061+
{AArch64::ZAD1, AArch64::ZAD3, AArch64::ZAD5, AArch64::ZAD7}},
2062+
{{32, AArch64::ZAS0}, {AArch64::ZAD0, AArch64::ZAD4}},
2063+
{{32, AArch64::ZAS1}, {AArch64::ZAD1, AArch64::ZAD5}},
2064+
{{32, AArch64::ZAS2}, {AArch64::ZAD2, AArch64::ZAD6}},
2065+
{{32, AArch64::ZAS3}, {AArch64::ZAD3, AArch64::ZAD7}},
2066+
};
2067+
2068+
if (ElementWidth == 64)
2069+
OutRegs.insert(Reg);
2070+
else {
2071+
std::vector<unsigned> Regs = RegMap[std::make_pair(ElementWidth, Reg)];
2072+
assert(!Regs.empty() && "Invalid tile or element width!");
2073+
for (auto OutReg : Regs)
2074+
OutRegs.insert(OutReg);
2075+
}
2076+
}
2077+
20152078
static std::unique_ptr<AArch64Operand> CreateImm(const MCExpr *Val, SMLoc S,
20162079
SMLoc E, MCContext &Ctx) {
20172080
auto Op = std::make_unique<AArch64Operand>(k_Immediate, Ctx);
@@ -2235,6 +2298,15 @@ void AArch64Operand::print(raw_ostream &OS) const {
22352298
case k_MatrixRegister:
22362299
OS << "<matrix " << getMatrixReg() << ">";
22372300
break;
2301+
case k_MatrixTileList: {
2302+
OS << "<matrixlist ";
2303+
unsigned RegMask = getMatrixTileListRegMask();
2304+
unsigned MaxBits = 8;
2305+
for (unsigned I = MaxBits; I > 0; --I)
2306+
OS << ((RegMask & (1 << (I - 1))) >> (I - 1));
2307+
OS << '>';
2308+
break;
2309+
}
22382310
case k_SVCR: {
22392311
OS << getSVCR();
22402312
break;
@@ -2418,6 +2490,26 @@ static unsigned matchSVEPredicateVectorRegName(StringRef Name) {
24182490
.Default(0);
24192491
}
24202492

2493+
static unsigned matchMatrixTileListRegName(StringRef Name) {
2494+
return StringSwitch<unsigned>(Name.lower())
2495+
.Case("za0.d", AArch64::ZAD0)
2496+
.Case("za1.d", AArch64::ZAD1)
2497+
.Case("za2.d", AArch64::ZAD2)
2498+
.Case("za3.d", AArch64::ZAD3)
2499+
.Case("za4.d", AArch64::ZAD4)
2500+
.Case("za5.d", AArch64::ZAD5)
2501+
.Case("za6.d", AArch64::ZAD6)
2502+
.Case("za7.d", AArch64::ZAD7)
2503+
.Case("za0.s", AArch64::ZAS0)
2504+
.Case("za1.s", AArch64::ZAS1)
2505+
.Case("za2.s", AArch64::ZAS2)
2506+
.Case("za3.s", AArch64::ZAS3)
2507+
.Case("za0.h", AArch64::ZAH0)
2508+
.Case("za1.h", AArch64::ZAH1)
2509+
.Case("za0.b", AArch64::ZAB0)
2510+
.Default(0);
2511+
}
2512+
24212513
static unsigned matchMatrixRegName(StringRef Name) {
24222514
return StringSwitch<unsigned>(Name.lower())
24232515
.Case("za", AArch64::ZA)
@@ -3763,6 +3855,120 @@ bool AArch64AsmParser::parseSymbolicImmVal(const MCExpr *&ImmVal) {
37633855
return false;
37643856
}
37653857

3858+
OperandMatchResultTy
3859+
AArch64AsmParser::tryParseMatrixTileList(OperandVector &Operands) {
3860+
MCAsmParser &Parser = getParser();
3861+
3862+
if (Parser.getTok().isNot(AsmToken::LCurly))
3863+
return MatchOperand_NoMatch;
3864+
3865+
auto ParseMatrixTile = [this, &Parser](unsigned &Reg,
3866+
unsigned &ElementWidth) {
3867+
StringRef Name = Parser.getTok().getString();
3868+
size_t DotPosition = Name.find('.');
3869+
if (DotPosition == StringRef::npos)
3870+
return MatchOperand_NoMatch;
3871+
3872+
unsigned RegNum = matchMatrixTileListRegName(Name);
3873+
if (!RegNum)
3874+
return MatchOperand_NoMatch;
3875+
3876+
StringRef Tail = Name.drop_front(DotPosition);
3877+
const Optional<std::pair<int, int>> &KindRes =
3878+
parseVectorKind(Tail, RegKind::Matrix);
3879+
if (!KindRes) {
3880+
TokError("Expected the register to be followed by element width suffix");
3881+
return MatchOperand_ParseFail;
3882+
}
3883+
ElementWidth = KindRes->second;
3884+
Reg = RegNum;
3885+
Parser.Lex(); // Eat the register.
3886+
return MatchOperand_Success;
3887+
};
3888+
3889+
SMLoc S = getLoc();
3890+
auto LCurly = Parser.getTok();
3891+
Parser.Lex(); // Eat left bracket token.
3892+
3893+
// Empty matrix list
3894+
if (parseOptionalToken(AsmToken::RCurly)) {
3895+
Operands.push_back(AArch64Operand::CreateMatrixTileList(
3896+
/*RegMask=*/0, S, getLoc(), getContext()));
3897+
return MatchOperand_Success;
3898+
}
3899+
3900+
// Try parse {za} alias early
3901+
if (Parser.getTok().getString().equals_insensitive("za")) {
3902+
Parser.Lex(); // Eat 'za'
3903+
3904+
if (parseToken(AsmToken::RCurly, "'}' expected"))
3905+
return MatchOperand_ParseFail;
3906+
3907+
Operands.push_back(AArch64Operand::CreateMatrixTileList(
3908+
/*RegMask=*/0xFF, S, getLoc(), getContext()));
3909+
return MatchOperand_Success;
3910+
}
3911+
3912+
SMLoc TileLoc = getLoc();
3913+
3914+
unsigned FirstReg, ElementWidth;
3915+
auto ParseRes = ParseMatrixTile(FirstReg, ElementWidth);
3916+
if (ParseRes != MatchOperand_Success) {
3917+
Parser.getLexer().UnLex(LCurly);
3918+
return ParseRes;
3919+
}
3920+
3921+
const MCRegisterInfo *RI = getContext().getRegisterInfo();
3922+
3923+
unsigned PrevReg = FirstReg;
3924+
unsigned Count = 1;
3925+
3926+
SmallSet<unsigned, 8> DRegs;
3927+
AArch64Operand::ComputeRegsForAlias(FirstReg, DRegs, ElementWidth);
3928+
3929+
SmallSet<unsigned, 8> SeenRegs;
3930+
SeenRegs.insert(FirstReg);
3931+
3932+
while (parseOptionalToken(AsmToken::Comma)) {
3933+
TileLoc = getLoc();
3934+
unsigned Reg, NextElementWidth;
3935+
ParseRes = ParseMatrixTile(Reg, NextElementWidth);
3936+
if (ParseRes != MatchOperand_Success)
3937+
return ParseRes;
3938+
3939+
// Element size must match on all regs in the list.
3940+
if (ElementWidth != NextElementWidth) {
3941+
Error(TileLoc, "mismatched register size suffix");
3942+
return MatchOperand_ParseFail;
3943+
}
3944+
3945+
if (RI->getEncodingValue(Reg) <= (RI->getEncodingValue(PrevReg)))
3946+
Warning(TileLoc, "tile list not in ascending order");
3947+
3948+
if (SeenRegs.contains(Reg))
3949+
Warning(TileLoc, "duplicate tile in list");
3950+
else {
3951+
SeenRegs.insert(Reg);
3952+
AArch64Operand::ComputeRegsForAlias(Reg, DRegs, ElementWidth);
3953+
}
3954+
3955+
PrevReg = Reg;
3956+
++Count;
3957+
}
3958+
3959+
if (parseToken(AsmToken::RCurly, "'}' expected"))
3960+
return MatchOperand_ParseFail;
3961+
3962+
unsigned RegMask = 0;
3963+
for (auto Reg : DRegs)
3964+
RegMask |= 0x1 << (RI->getEncodingValue(Reg) -
3965+
RI->getEncodingValue(AArch64::ZAD0));
3966+
Operands.push_back(
3967+
AArch64Operand::CreateMatrixTileList(RegMask, S, getLoc(), getContext()));
3968+
3969+
return MatchOperand_Success;
3970+
}
3971+
37663972
template <RegKind VectorKind>
37673973
OperandMatchResultTy
37683974
AArch64AsmParser::tryParseVectorList(OperandVector &Operands,

llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ static DecodeStatus DecodeZPR4RegisterClass(MCInst &Inst, unsigned RegNo,
118118
template <unsigned NumBitsForTile>
119119
static DecodeStatus DecodeMatrixTile(MCInst &Inst, unsigned RegNo,
120120
uint64_t Address, const void *Decoder);
121+
static DecodeStatus DecodeMatrixTileListRegisterClass(MCInst &Inst,
122+
unsigned RegMask,
123+
uint64_t Address,
124+
const void *Decoder);
121125
static DecodeStatus DecodePPRRegisterClass(MCInst &Inst, unsigned RegNo,
122126
uint64_t Address,
123127
const void *Decoder);
@@ -704,6 +708,16 @@ static DecodeStatus DecodeZPR4RegisterClass(MCInst &Inst, unsigned RegNo,
704708
return Success;
705709
}
706710

711+
static DecodeStatus DecodeMatrixTileListRegisterClass(MCInst &Inst,
712+
unsigned RegMask,
713+
uint64_t Address,
714+
const void *Decoder) {
715+
if (RegMask > 0xFF)
716+
return Fail;
717+
Inst.addOperand(MCOperand::createImm(RegMask));
718+
return Success;
719+
}
720+
707721
static const SmallVector<SmallVector<unsigned, 16>, 5>
708722
MatrixZATileDecoderTable = {
709723
{AArch64::ZAB0},

llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,36 @@ void AArch64InstPrinter::printGPRSeqPairsClassOperand(const MCInst *MI,
13401340
O << getRegisterName(Even) << ", " << getRegisterName(Odd);
13411341
}
13421342

1343+
static const unsigned MatrixZADRegisterTable[] = {
1344+
AArch64::ZAD0, AArch64::ZAD1, AArch64::ZAD2, AArch64::ZAD3,
1345+
AArch64::ZAD4, AArch64::ZAD5, AArch64::ZAD6, AArch64::ZAD7
1346+
};
1347+
1348+
void AArch64InstPrinter::printMatrixTileList(const MCInst *MI, unsigned OpNum,
1349+
const MCSubtargetInfo &STI,
1350+
raw_ostream &O) {
1351+
unsigned MaxRegs = 8;
1352+
unsigned RegMask = MI->getOperand(OpNum).getImm();
1353+
1354+
unsigned NumRegs = 0;
1355+
for (unsigned I = 0; I < MaxRegs; ++I)
1356+
if ((RegMask & (1 << I)) != 0)
1357+
++NumRegs;
1358+
1359+
O << "{";
1360+
unsigned Printed = 0;
1361+
for (unsigned I = 0; I < MaxRegs; ++I) {
1362+
unsigned Reg = RegMask & (1 << I);
1363+
if (Reg == 0)
1364+
continue;
1365+
O << getRegisterName(MatrixZADRegisterTable[I]);
1366+
if (Printed + 1 != NumRegs)
1367+
O << ", ";
1368+
++Printed;
1369+
}
1370+
O << "}";
1371+
}
1372+
13431373
void AArch64InstPrinter::printVectorList(const MCInst *MI, unsigned OpNum,
13441374
const MCSubtargetInfo &STI,
13451375
raw_ostream &O,

0 commit comments

Comments
 (0)