|
18 | 18 | #include "llvm/ADT/APInt.h"
|
19 | 19 | #include "llvm/ADT/ArrayRef.h"
|
20 | 20 | #include "llvm/ADT/STLExtras.h"
|
| 21 | +#include "llvm/ADT/SmallSet.h" |
21 | 22 | #include "llvm/ADT/SmallVector.h"
|
22 | 23 | #include "llvm/ADT/StringExtras.h"
|
23 | 24 | #include "llvm/ADT/StringMap.h"
|
@@ -262,6 +263,7 @@ class AArch64AsmParser : public MCTargetAsmParser {
|
262 | 263 | template <RegKind VectorKind>
|
263 | 264 | OperandMatchResultTy tryParseVectorList(OperandVector &Operands,
|
264 | 265 | bool ExpectMatch = false);
|
| 266 | + OperandMatchResultTy tryParseMatrixTileList(OperandVector &Operands); |
265 | 267 | OperandMatchResultTy tryParseSVEPattern(OperandVector &Operands);
|
266 | 268 | OperandMatchResultTy tryParseGPR64x8(OperandVector &Operands);
|
267 | 269 |
|
@@ -322,6 +324,7 @@ class AArch64Operand : public MCParsedAsmOperand {
|
322 | 324 | k_CondCode,
|
323 | 325 | k_Register,
|
324 | 326 | k_MatrixRegister,
|
| 327 | + k_MatrixTileList, |
325 | 328 | k_SVCR,
|
326 | 329 | k_VectorList,
|
327 | 330 | k_VectorIndex,
|
@@ -383,6 +386,10 @@ class AArch64Operand : public MCParsedAsmOperand {
|
383 | 386 | MatrixKind Kind;
|
384 | 387 | };
|
385 | 388 |
|
/// Payload stored in an operand of kind k_MatrixTileList.
struct MatrixTileListOp {
  /// Tile mask, zero by default. tryParseMatrixTileList builds it by setting
  /// bit (encoding(Reg) - encoding(ZAD0)) for every ZAD tile covered by the
  /// parsed list; addMatrixTileListOperands asserts it fits in 8 bits.
  unsigned RegMask{0};
};
| 392 | + |
386 | 393 | struct VectorListOp {
|
387 | 394 | unsigned RegNum;
|
388 | 395 | unsigned Count;
|
@@ -460,6 +467,7 @@ class AArch64Operand : public MCParsedAsmOperand {
|
460 | 467 | struct TokOp Tok;
|
461 | 468 | struct RegOp Reg;
|
462 | 469 | struct MatrixRegOp MatrixReg;
|
| 470 | + struct MatrixTileListOp MatrixTileList; |
463 | 471 | struct VectorListOp VectorList;
|
464 | 472 | struct VectorIndexOp VectorIndex;
|
465 | 473 | struct ImmOp Imm;
|
@@ -512,6 +520,9 @@ class AArch64Operand : public MCParsedAsmOperand {
|
512 | 520 | case k_MatrixRegister:
|
513 | 521 | MatrixReg = o.MatrixReg;
|
514 | 522 | break;
|
| 523 | + case k_MatrixTileList: |
| 524 | + MatrixTileList = o.MatrixTileList; |
| 525 | + break; |
515 | 526 | case k_VectorList:
|
516 | 527 | VectorList = o.VectorList;
|
517 | 528 | break;
|
@@ -622,6 +633,11 @@ class AArch64Operand : public MCParsedAsmOperand {
|
622 | 633 | return MatrixReg.Kind;
|
623 | 634 | }
|
624 | 635 |
|
| 636 | + unsigned getMatrixTileListRegMask() const { |
| 637 | + assert(isMatrixTileList() && "Invalid access!"); |
| 638 | + return MatrixTileList.RegMask; |
| 639 | + } |
| 640 | + |
625 | 641 | RegConstraintEqualityTy getRegEqualityTy() const {
|
626 | 642 | assert(Kind == k_Register && "Invalid access!");
|
627 | 643 | return Reg.EqualityTy;
|
@@ -1143,6 +1159,7 @@ class AArch64Operand : public MCParsedAsmOperand {
|
1143 | 1159 | }
|
1144 | 1160 |
|
1145 | 1161 | bool isMatrix() const { return Kind == k_MatrixRegister; }
|
| 1162 | + bool isMatrixTileList() const { return Kind == k_MatrixTileList; } |
1146 | 1163 |
|
1147 | 1164 | template <unsigned Class> bool isSVEVectorReg() const {
|
1148 | 1165 | RegKind RK;
|
@@ -1643,6 +1660,13 @@ class AArch64Operand : public MCParsedAsmOperand {
|
1643 | 1660 | FirstRegs[(unsigned)RegTy][0]));
|
1644 | 1661 | }
|
1645 | 1662 |
|
| 1663 | + void addMatrixTileListOperands(MCInst &Inst, unsigned N) const { |
| 1664 | + assert(N == 1 && "Invalid number of operands!"); |
| 1665 | + unsigned RegMask = getMatrixTileListRegMask(); |
| 1666 | + assert(RegMask <= 0xFF && "Invalid mask!"); |
| 1667 | + Inst.addOperand(MCOperand::createImm(RegMask)); |
| 1668 | + } |
| 1669 | + |
1646 | 1670 | void addVectorIndexOperands(MCInst &Inst, unsigned N) const {
|
1647 | 1671 | assert(N == 1 && "Invalid number of operands!");
|
1648 | 1672 | Inst.addOperand(MCOperand::createImm(getVectorIndex()));
|
@@ -2012,6 +2036,45 @@ class AArch64Operand : public MCParsedAsmOperand {
|
2012 | 2036 | return Op;
|
2013 | 2037 | }
|
2014 | 2038 |
|
| 2039 | + static std::unique_ptr<AArch64Operand> |
| 2040 | + CreateMatrixTileList(unsigned RegMask, SMLoc S, SMLoc E, MCContext &Ctx) { |
| 2041 | + auto Op = std::make_unique<AArch64Operand>(k_MatrixTileList, Ctx); |
| 2042 | + Op->MatrixTileList.RegMask = RegMask; |
| 2043 | + Op->StartLoc = S; |
| 2044 | + Op->EndLoc = E; |
| 2045 | + return Op; |
| 2046 | + } |
| 2047 | + |
| 2048 | + static void ComputeRegsForAlias(unsigned Reg, SmallSet<unsigned, 8> &OutRegs, |
| 2049 | + const unsigned ElementWidth) { |
| 2050 | + static std::map<std::pair<unsigned, unsigned>, std::vector<unsigned>> |
| 2051 | + RegMap = { |
| 2052 | + {{0, AArch64::ZAB0}, |
| 2053 | + {AArch64::ZAD0, AArch64::ZAD1, AArch64::ZAD2, AArch64::ZAD3, |
| 2054 | + AArch64::ZAD4, AArch64::ZAD5, AArch64::ZAD6, AArch64::ZAD7}}, |
| 2055 | + {{8, AArch64::ZAB0}, |
| 2056 | + {AArch64::ZAD0, AArch64::ZAD1, AArch64::ZAD2, AArch64::ZAD3, |
| 2057 | + AArch64::ZAD4, AArch64::ZAD5, AArch64::ZAD6, AArch64::ZAD7}}, |
| 2058 | + {{16, AArch64::ZAH0}, |
| 2059 | + {AArch64::ZAD0, AArch64::ZAD2, AArch64::ZAD4, AArch64::ZAD6}}, |
| 2060 | + {{16, AArch64::ZAH1}, |
| 2061 | + {AArch64::ZAD1, AArch64::ZAD3, AArch64::ZAD5, AArch64::ZAD7}}, |
| 2062 | + {{32, AArch64::ZAS0}, {AArch64::ZAD0, AArch64::ZAD4}}, |
| 2063 | + {{32, AArch64::ZAS1}, {AArch64::ZAD1, AArch64::ZAD5}}, |
| 2064 | + {{32, AArch64::ZAS2}, {AArch64::ZAD2, AArch64::ZAD6}}, |
| 2065 | + {{32, AArch64::ZAS3}, {AArch64::ZAD3, AArch64::ZAD7}}, |
| 2066 | + }; |
| 2067 | + |
| 2068 | + if (ElementWidth == 64) |
| 2069 | + OutRegs.insert(Reg); |
| 2070 | + else { |
| 2071 | + std::vector<unsigned> Regs = RegMap[std::make_pair(ElementWidth, Reg)]; |
| 2072 | + assert(!Regs.empty() && "Invalid tile or element width!"); |
| 2073 | + for (auto OutReg : Regs) |
| 2074 | + OutRegs.insert(OutReg); |
| 2075 | + } |
| 2076 | + } |
| 2077 | + |
2015 | 2078 | static std::unique_ptr<AArch64Operand> CreateImm(const MCExpr *Val, SMLoc S,
|
2016 | 2079 | SMLoc E, MCContext &Ctx) {
|
2017 | 2080 | auto Op = std::make_unique<AArch64Operand>(k_Immediate, Ctx);
|
@@ -2235,6 +2298,15 @@ void AArch64Operand::print(raw_ostream &OS) const {
|
2235 | 2298 | case k_MatrixRegister:
|
2236 | 2299 | OS << "<matrix " << getMatrixReg() << ">";
|
2237 | 2300 | break;
|
| 2301 | + case k_MatrixTileList: { |
| 2302 | + OS << "<matrixlist "; |
| 2303 | + unsigned RegMask = getMatrixTileListRegMask(); |
| 2304 | + unsigned MaxBits = 8; |
| 2305 | + for (unsigned I = MaxBits; I > 0; --I) |
| 2306 | + OS << ((RegMask & (1 << (I - 1))) >> (I - 1)); |
| 2307 | + OS << '>'; |
| 2308 | + break; |
| 2309 | + } |
2238 | 2310 | case k_SVCR: {
|
2239 | 2311 | OS << getSVCR();
|
2240 | 2312 | break;
|
@@ -2418,6 +2490,26 @@ static unsigned matchSVEPredicateVectorRegName(StringRef Name) {
|
2418 | 2490 | .Default(0);
|
2419 | 2491 | }
|
2420 | 2492 |
|
| 2493 | +static unsigned matchMatrixTileListRegName(StringRef Name) { |
| 2494 | + return StringSwitch<unsigned>(Name.lower()) |
| 2495 | + .Case("za0.d", AArch64::ZAD0) |
| 2496 | + .Case("za1.d", AArch64::ZAD1) |
| 2497 | + .Case("za2.d", AArch64::ZAD2) |
| 2498 | + .Case("za3.d", AArch64::ZAD3) |
| 2499 | + .Case("za4.d", AArch64::ZAD4) |
| 2500 | + .Case("za5.d", AArch64::ZAD5) |
| 2501 | + .Case("za6.d", AArch64::ZAD6) |
| 2502 | + .Case("za7.d", AArch64::ZAD7) |
| 2503 | + .Case("za0.s", AArch64::ZAS0) |
| 2504 | + .Case("za1.s", AArch64::ZAS1) |
| 2505 | + .Case("za2.s", AArch64::ZAS2) |
| 2506 | + .Case("za3.s", AArch64::ZAS3) |
| 2507 | + .Case("za0.h", AArch64::ZAH0) |
| 2508 | + .Case("za1.h", AArch64::ZAH1) |
| 2509 | + .Case("za0.b", AArch64::ZAB0) |
| 2510 | + .Default(0); |
| 2511 | +} |
| 2512 | + |
2421 | 2513 | static unsigned matchMatrixRegName(StringRef Name) {
|
2422 | 2514 | return StringSwitch<unsigned>(Name.lower())
|
2423 | 2515 | .Case("za", AArch64::ZA)
|
@@ -3763,6 +3855,120 @@ bool AArch64AsmParser::parseSymbolicImmVal(const MCExpr *&ImmVal) {
|
3763 | 3855 | return false;
|
3764 | 3856 | }
|
3765 | 3857 |
|
| 3858 | +OperandMatchResultTy |
| 3859 | +AArch64AsmParser::tryParseMatrixTileList(OperandVector &Operands) { |
| 3860 | + MCAsmParser &Parser = getParser(); |
| 3861 | + |
| 3862 | + if (Parser.getTok().isNot(AsmToken::LCurly)) |
| 3863 | + return MatchOperand_NoMatch; |
| 3864 | + |
| 3865 | + auto ParseMatrixTile = [this, &Parser](unsigned &Reg, |
| 3866 | + unsigned &ElementWidth) { |
| 3867 | + StringRef Name = Parser.getTok().getString(); |
| 3868 | + size_t DotPosition = Name.find('.'); |
| 3869 | + if (DotPosition == StringRef::npos) |
| 3870 | + return MatchOperand_NoMatch; |
| 3871 | + |
| 3872 | + unsigned RegNum = matchMatrixTileListRegName(Name); |
| 3873 | + if (!RegNum) |
| 3874 | + return MatchOperand_NoMatch; |
| 3875 | + |
| 3876 | + StringRef Tail = Name.drop_front(DotPosition); |
| 3877 | + const Optional<std::pair<int, int>> &KindRes = |
| 3878 | + parseVectorKind(Tail, RegKind::Matrix); |
| 3879 | + if (!KindRes) { |
| 3880 | + TokError("Expected the register to be followed by element width suffix"); |
| 3881 | + return MatchOperand_ParseFail; |
| 3882 | + } |
| 3883 | + ElementWidth = KindRes->second; |
| 3884 | + Reg = RegNum; |
| 3885 | + Parser.Lex(); // Eat the register. |
| 3886 | + return MatchOperand_Success; |
| 3887 | + }; |
| 3888 | + |
| 3889 | + SMLoc S = getLoc(); |
| 3890 | + auto LCurly = Parser.getTok(); |
| 3891 | + Parser.Lex(); // Eat left bracket token. |
| 3892 | + |
| 3893 | + // Empty matrix list |
| 3894 | + if (parseOptionalToken(AsmToken::RCurly)) { |
| 3895 | + Operands.push_back(AArch64Operand::CreateMatrixTileList( |
| 3896 | + /*RegMask=*/0, S, getLoc(), getContext())); |
| 3897 | + return MatchOperand_Success; |
| 3898 | + } |
| 3899 | + |
| 3900 | + // Try parse {za} alias early |
| 3901 | + if (Parser.getTok().getString().equals_insensitive("za")) { |
| 3902 | + Parser.Lex(); // Eat 'za' |
| 3903 | + |
| 3904 | + if (parseToken(AsmToken::RCurly, "'}' expected")) |
| 3905 | + return MatchOperand_ParseFail; |
| 3906 | + |
| 3907 | + Operands.push_back(AArch64Operand::CreateMatrixTileList( |
| 3908 | + /*RegMask=*/0xFF, S, getLoc(), getContext())); |
| 3909 | + return MatchOperand_Success; |
| 3910 | + } |
| 3911 | + |
| 3912 | + SMLoc TileLoc = getLoc(); |
| 3913 | + |
| 3914 | + unsigned FirstReg, ElementWidth; |
| 3915 | + auto ParseRes = ParseMatrixTile(FirstReg, ElementWidth); |
| 3916 | + if (ParseRes != MatchOperand_Success) { |
| 3917 | + Parser.getLexer().UnLex(LCurly); |
| 3918 | + return ParseRes; |
| 3919 | + } |
| 3920 | + |
| 3921 | + const MCRegisterInfo *RI = getContext().getRegisterInfo(); |
| 3922 | + |
| 3923 | + unsigned PrevReg = FirstReg; |
| 3924 | + unsigned Count = 1; |
| 3925 | + |
| 3926 | + SmallSet<unsigned, 8> DRegs; |
| 3927 | + AArch64Operand::ComputeRegsForAlias(FirstReg, DRegs, ElementWidth); |
| 3928 | + |
| 3929 | + SmallSet<unsigned, 8> SeenRegs; |
| 3930 | + SeenRegs.insert(FirstReg); |
| 3931 | + |
| 3932 | + while (parseOptionalToken(AsmToken::Comma)) { |
| 3933 | + TileLoc = getLoc(); |
| 3934 | + unsigned Reg, NextElementWidth; |
| 3935 | + ParseRes = ParseMatrixTile(Reg, NextElementWidth); |
| 3936 | + if (ParseRes != MatchOperand_Success) |
| 3937 | + return ParseRes; |
| 3938 | + |
| 3939 | + // Element size must match on all regs in the list. |
| 3940 | + if (ElementWidth != NextElementWidth) { |
| 3941 | + Error(TileLoc, "mismatched register size suffix"); |
| 3942 | + return MatchOperand_ParseFail; |
| 3943 | + } |
| 3944 | + |
| 3945 | + if (RI->getEncodingValue(Reg) <= (RI->getEncodingValue(PrevReg))) |
| 3946 | + Warning(TileLoc, "tile list not in ascending order"); |
| 3947 | + |
| 3948 | + if (SeenRegs.contains(Reg)) |
| 3949 | + Warning(TileLoc, "duplicate tile in list"); |
| 3950 | + else { |
| 3951 | + SeenRegs.insert(Reg); |
| 3952 | + AArch64Operand::ComputeRegsForAlias(Reg, DRegs, ElementWidth); |
| 3953 | + } |
| 3954 | + |
| 3955 | + PrevReg = Reg; |
| 3956 | + ++Count; |
| 3957 | + } |
| 3958 | + |
| 3959 | + if (parseToken(AsmToken::RCurly, "'}' expected")) |
| 3960 | + return MatchOperand_ParseFail; |
| 3961 | + |
| 3962 | + unsigned RegMask = 0; |
| 3963 | + for (auto Reg : DRegs) |
| 3964 | + RegMask |= 0x1 << (RI->getEncodingValue(Reg) - |
| 3965 | + RI->getEncodingValue(AArch64::ZAD0)); |
| 3966 | + Operands.push_back( |
| 3967 | + AArch64Operand::CreateMatrixTileList(RegMask, S, getLoc(), getContext())); |
| 3968 | + |
| 3969 | + return MatchOperand_Success; |
| 3970 | +} |
| 3971 | + |
3766 | 3972 | template <RegKind VectorKind>
|
3767 | 3973 | OperandMatchResultTy
|
3768 | 3974 | AArch64AsmParser::tryParseVectorList(OperandVector &Operands,
|
|
0 commit comments