Skip to content

Commit be369ea

Browse files
committed
[AArch64][SVE2] Add the SVE2.1 while & pext predicate pair instructions
This patch adds the assembly/disassembly for the following predicate pair instructions:
- pext: Set pair of predicates from predicate-as-counter
- whilelt: While incrementing signed scalar less than scalar
- whilele: While incrementing signed scalar less than or equal to scalar
- whilegt: While decrementing signed scalar greater than scalar
- whilege: While decrementing signed scalar greater than or equal to scalar
- whilelo: While incrementing unsigned scalar lower than scalar
- whilels: While incrementing unsigned scalar lower or same as scalar
- whilehs: While decrementing unsigned scalar higher or same as scalar
- whilehi: While decrementing unsigned scalar higher than scalar

The reference can be found here: https://developer.arm.com/documentation/ddi0602/2022-09

Differential Revision: https://reviews.llvm.org/D136759
1 parent 870fbf8 commit be369ea

27 files changed

+1216
-25
lines changed

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,75 @@ def PNR32_p8to15 : PNRP8to15RegOp<"s", PNRAsmOp32_p8to15, 32, PPR_p8to15>;
953953
def PNR64_p8to15 : PNRP8to15RegOp<"d", PNRAsmOp64_p8to15, 64, PPR_p8to15>;
954954

955955

956+
let Namespace = "AArch64" in {
957+
def psub0 : SubRegIndex<16, -1>;
958+
def psub1 : SubRegIndex<16, -1>;
959+
}
960+
961+
// Pairs of SVE predicate vector registers.
962+
def PSeqPairs : RegisterTuples<[psub0, psub1], [(rotl PPR, 0), (rotl PPR, 1)]>;
963+
964+
def PPR2 : RegisterClass<"AArch64", [untyped], 16, (add PSeqPairs)> {
965+
let Size = 32;
966+
}
967+
968+
// Assembler operand class describing a list of NumRegs SVE predicate
// registers whose elements are ElementWidth bits wide.
//
// The operand class name is assembled as
// "SVEPredicateList<NumRegs>x<ElementWidth>". Parsing, matching and rendering
// are delegated to the C++ methods named in the strings below, instantiated
// for RegKind::SVEPredicateVector / AArch64Operand::VecListIdx_PReg.
// NOTE(review): the literal 0 passed to isTypedVectorList is presumably the
// "number of elements" argument (unspecified) - confirm against the parser.
class PPRVectorList<int ElementWidth, int NumRegs> : AsmOperandClass {
969+
let Name = "SVEPredicateList" # NumRegs # "x" # ElementWidth;
970+
let ParserMethod = "tryParseVectorList<RegKind::SVEPredicateVector>";
971+
let PredicateMethod = "isTypedVectorList<RegKind::SVEPredicateVector, "
972+
# NumRegs #", 0, "#ElementWidth #">";
973+
let RenderMethod = "addVectorListOperands<AArch64Operand::VecListIdx_PReg, "
974+
# NumRegs #">";
975+
}
976+
977+
// Register operands for a pair of SVE predicate registers (register class
// PPR2), one definition per element-size suffix. Each operand is printed
// through printTypedVectorList with its suffix character and is matched by a
// two-register PPRVectorList of the corresponding element width:
// 'b' = 8, 'h' = 16, 's' = 32, 'd' = 64 bits.
def PP_b : RegisterOperand<PPR2, "printTypedVectorList<0,'b'>"> {
978+
let ParserMatchClass = PPRVectorList<8, 2>;
979+
}
980+
981+
def PP_h : RegisterOperand<PPR2, "printTypedVectorList<0,'h'>"> {
982+
let ParserMatchClass = PPRVectorList<16, 2>;
983+
}
984+
985+
def PP_s : RegisterOperand<PPR2, "printTypedVectorList<0,'s'>"> {
986+
let ParserMatchClass = PPRVectorList<32, 2>;
987+
}
988+
989+
def PP_d : RegisterOperand<PPR2, "printTypedVectorList<0,'d'>"> {
990+
let ParserMatchClass = PPRVectorList<64, 2>;
991+
}
992+
993+
// SVE2 multiple-of-2 multi-predicate-vector operands
994+
def PPR2Mul2 : RegisterClass<"AArch64", [untyped], 16, (add (decimate PSeqPairs, 2))> {
995+
let Size = 32;
996+
}
997+
998+
// Variant of PPRVectorList used by the "_mul_r" predicate-pair operands.
//
// It overrides the operand class name ("SVEPredicateListMul<NumRegs>x
// <ElementWidth>"), derives a dedicated diagnostic ("Invalid" # Name, e.g.
// InvalidSVEPredicateListMul2x8) and swaps the predicate for
// isTypedVectorListMultiple. ParserMethod and RenderMethod are inherited
// unchanged from PPRVectorList.
// NOTE(review): isTypedVectorListMultiple presumably also checks that the
// first register is a multiple of NumRegs - confirm in AArch64AsmParser.cpp.
class PPRVectorListMul<int ElementWidth, int NumRegs> : PPRVectorList<ElementWidth, NumRegs> {
999+
let Name = "SVEPredicateListMul" # NumRegs # "x" # ElementWidth;
1000+
let DiagnosticType = "Invalid" # Name;
1001+
let PredicateMethod =
1002+
"isTypedVectorListMultiple<RegKind::SVEPredicateVector, " # NumRegs # ", 0, "
1003+
# ElementWidth # ">";
1004+
}
1005+
1006+
// Predicate-pair register operands over PPR2Mul2 (every second entry of the
// predicate pair sequence), one per element-size suffix. All four share a
// custom encoder (EncodeRegAsMultipleOf<2>) and decoder
// (DecodePPR2Mul2RegisterClass), and are matched through PPRVectorListMul so
// that a mismatch reports the InvalidSVEPredicateListMul2x<width> diagnostic.
let EncoderMethod = "EncodeRegAsMultipleOf<2>",
1007+
DecoderMethod = "DecodePPR2Mul2RegisterClass" in {
1008+
def PP_b_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'b'>"> {
1009+
let ParserMatchClass = PPRVectorListMul<8, 2>;
1010+
}
1011+
1012+
def PP_h_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'h'>"> {
1013+
let ParserMatchClass = PPRVectorListMul<16, 2>;
1014+
}
1015+
1016+
def PP_s_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'s'>"> {
1017+
let ParserMatchClass = PPRVectorListMul<32, 2>;
1018+
}
1019+
1020+
def PP_d_mul_r : RegisterOperand<PPR2Mul2, "printTypedVectorList<0,'d'>"> {
1021+
let ParserMatchClass = PPRVectorListMul<64, 2>;
1022+
}
1023+
} // end let EncoderMethod/DecoderMethod
1024+
9561025

9571026
//******************************************************************************
9581027

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3591,6 +3591,7 @@ def UDOT_ZZZI_HtoS : sve2p1_two_way_dot_vvi<"udot", 0b1>;
35913591

35923592
defm CNTP_XCI : sve2p1_pcount_pn<"cntp", 0b000>;
35933593
defm PEXT_PCI : sve2p1_pred_as_ctr_to_mask<"pext">;
3594+
defm PEXT_2PCI : sve2p1_pred_as_ctr_to_mask_pair<"pext">;
35943595
defm PTRUE_C : sve2p1_ptrue_pn<"ptrue">;
35953596

35963597
defm SQCVTN_Z2Z_StoH : sve2p1_multi_vec_extract_narrow<"sqcvtn", 0b00>;
@@ -3672,6 +3673,14 @@ defm STNT1H_4Z_IMM : sve2p1_mem_cst_si_4z<"stnt1h", 0b01, 0b1, ZZZZ_h_mul_r>;
36723673
defm STNT1W_4Z_IMM : sve2p1_mem_cst_si_4z<"stnt1w", 0b10, 0b1, ZZZZ_s_mul_r>;
36733674
defm STNT1D_4Z_IMM : sve2p1_mem_cst_si_4z<"stnt1d", 0b11, 0b1, ZZZZ_d_mul_r>;
36743675

3676+
defm WHILEGE_2PXX : sve2p1_int_while_rr_pair<"whilege", 0b000>;
3677+
defm WHILEGT_2PXX : sve2p1_int_while_rr_pair<"whilegt", 0b001>;
3678+
defm WHILELT_2PXX : sve2p1_int_while_rr_pair<"whilelt", 0b010>;
3679+
defm WHILELE_2PXX : sve2p1_int_while_rr_pair<"whilele", 0b011>;
3680+
defm WHILEHS_2PXX : sve2p1_int_while_rr_pair<"whilehs", 0b100>;
3681+
defm WHILEHI_2PXX : sve2p1_int_while_rr_pair<"whilehi", 0b101>;
3682+
defm WHILELO_2PXX : sve2p1_int_while_rr_pair<"whilelo", 0b110>;
3683+
defm WHILELS_2PXX : sve2p1_int_while_rr_pair<"whilels", 0b111>;
36753684
defm WHILEGE_CXX : sve2p1_int_while_rr_pn<"whilege", 0b000>;
36763685
defm WHILEGT_CXX : sve2p1_int_while_rr_pn<"whilegt", 0b001>;
36773686
defm WHILELT_CXX : sve2p1_int_while_rr_pn<"whilelt", 0b010>;

llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class AArch64AsmParser : public MCTargetAsmParser {
225225

226226
bool validateInstruction(MCInst &Inst, SMLoc &IDLoc,
227227
SmallVectorImpl<SMLoc> &Loc);
228+
unsigned getNumRegsForRegKind(RegKind K);
228229
bool MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
229230
OperandVector &Operands, MCStreamer &Out,
230231
uint64_t &ErrorInfo,
@@ -1726,6 +1727,7 @@ class AArch64Operand : public MCParsedAsmOperand {
17261727
VecListIdx_DReg = 0,
17271728
VecListIdx_QReg = 1,
17281729
VecListIdx_ZReg = 2,
1730+
VecListIdx_PReg = 3,
17291731
};
17301732

17311733
template <VecListIndexType RegTy, unsigned NumRegs>
@@ -1740,12 +1742,17 @@ class AArch64Operand : public MCParsedAsmOperand {
17401742
AArch64::Q0_Q1_Q2, AArch64::Q0_Q1_Q2_Q3 },
17411743
/* ZReg */ { AArch64::Z0,
17421744
AArch64::Z0, AArch64::Z0_Z1,
1743-
AArch64::Z0_Z1_Z2, AArch64::Z0_Z1_Z2_Z3 }
1745+
AArch64::Z0_Z1_Z2, AArch64::Z0_Z1_Z2_Z3 },
1746+
/* PReg */ { AArch64::P0,
1747+
AArch64::P0, AArch64::P0_P1 }
17441748
};
17451749

17461750
assert((RegTy != VecListIdx_ZReg || NumRegs <= 4) &&
17471751
" NumRegs must be <= 4 for ZRegs");
17481752

1753+
assert((RegTy != VecListIdx_PReg || NumRegs <= 2) &&
1754+
" NumRegs must be <= 2 for PRegs");
1755+
17491756
unsigned FirstReg = FirstRegs[(unsigned)RegTy][NumRegs];
17501757
Inst.addOperand(MCOperand::createReg(FirstReg + getVectorListStart() -
17511758
FirstRegs[(unsigned)RegTy][0]));
@@ -2807,6 +2814,20 @@ unsigned AArch64AsmParser::matchRegisterNameAlias(StringRef Name,
28072814
return RegNum;
28082815
}
28092816

2817+
/// Return the number of registers associated with register kind \p K.
///
/// Yields 32 for scalar, Neon vector and SVE data vector kinds, and 16 for
/// matrix, SVE predicate vector and SVE predicate-as-counter kinds. The vector
/// list parser uses this value as the modulus when register lists wrap around.
/// Any kind not listed in the switch reaches llvm_unreachable.
unsigned AArch64AsmParser::getNumRegsForRegKind(RegKind K) {
2818+
switch (K) {
2819+
case RegKind::Scalar:
2820+
case RegKind::NeonVector:
2821+
case RegKind::SVEDataVector:
2822+
return 32;
2823+
case RegKind::Matrix:
2824+
case RegKind::SVEPredicateVector:
2825+
case RegKind::SVEPredicateAsCounter:
2826+
return 16;
2827+
}
2828+
llvm_unreachable("Unsupported RegKind");
2829+
}
2830+
28102831
/// tryParseScalarRegister - Try to parse a register name. The token must be an
28112832
/// Identifier when called, and if it is a register name the token is eaten and
28122833
/// the register is added to the operand list.
@@ -4169,6 +4190,7 @@ AArch64AsmParser::tryParseVectorList(OperandVector &Operands,
41694190
return MatchOperand_NoMatch;
41704191
};
41714192

4193+
int NumRegs = getNumRegsForRegKind(VectorKind);
41724194
SMLoc S = getLoc();
41734195
auto LCurly = getTok();
41744196
Lex(); // Eat left bracket token.
@@ -4203,7 +4225,8 @@ AArch64AsmParser::tryParseVectorList(OperandVector &Operands,
42034225
return MatchOperand_ParseFail;
42044226
}
42054227

4206-
unsigned Space = (PrevReg < Reg) ? (Reg - PrevReg) : (Reg + 32 - PrevReg);
4228+
unsigned Space =
4229+
(PrevReg < Reg) ? (Reg - PrevReg) : (Reg + NumRegs - PrevReg);
42074230

42084231
if (Space == 0 || Space > 3) {
42094232
Error(Loc, "invalid number of vectors");
@@ -4229,7 +4252,8 @@ AArch64AsmParser::tryParseVectorList(OperandVector &Operands,
42294252

42304253
// Registers must be incremental (with wraparound at NumRegs - 1, i.e. 31 or 15 depending on the register kind)
42314254
if (getContext().getRegisterInfo()->getEncodingValue(Reg) !=
4232-
(getContext().getRegisterInfo()->getEncodingValue(PrevReg) + 1) % 32) {
4255+
(getContext().getRegisterInfo()->getEncodingValue(PrevReg) + 1) %
4256+
NumRegs) {
42334257
Error(Loc, "registers must be sequential");
42344258
return MatchOperand_ParseFail;
42354259
}
@@ -5678,6 +5702,13 @@ bool AArch64AsmParser::showMatchError(SMLoc Loc, unsigned ErrCode,
56785702
"pn0..pn15 with element suffix.");
56795703
case Match_InvalidSVEVecLenSpecifier:
56805704
return Error(Loc, "Invalid vector length specifier, expected VLx2 or VLx4");
5705+
case Match_InvalidSVEPredicateListMul2x8:
5706+
case Match_InvalidSVEPredicateListMul2x16:
5707+
case Match_InvalidSVEPredicateListMul2x32:
5708+
case Match_InvalidSVEPredicateListMul2x64:
5709+
return Error(Loc, "Invalid vector list, expected list with 2 consecutive "
5710+
"predicate registers, where the first vector is a multiple of 2 "
5711+
"and with correct element type");
56815712
case Match_InvalidSVEExactFPImmOperandHalfOne:
56825713
return Error(Loc, "Invalid floating point constant, expected 0.5 or 1.0.");
56835714
case Match_InvalidSVEExactFPImmOperandHalfTwo:
@@ -6262,6 +6293,10 @@ bool AArch64AsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
62626293
case Match_InvalidSVEPNPredicateHReg:
62636294
case Match_InvalidSVEPNPredicateSReg:
62646295
case Match_InvalidSVEPNPredicateDReg:
6296+
case Match_InvalidSVEPredicateListMul2x8:
6297+
case Match_InvalidSVEPredicateListMul2x16:
6298+
case Match_InvalidSVEPredicateListMul2x32:
6299+
case Match_InvalidSVEPredicateListMul2x64:
62656300
case Match_InvalidSVEExactFPImmOperandHalfOne:
62666301
case Match_InvalidSVEExactFPImmOperandHalfTwo:
62676302
case Match_InvalidSVEExactFPImmOperandZeroOne:

llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ static DecodeStatus DecodePPR_3bRegisterClass(MCInst &Inst, unsigned RegNo,
140140
static DecodeStatus
141141
DecodePPR_p8to15RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address,
142142
const MCDisassembler *Decoder);
143+
static DecodeStatus DecodePPR2RegisterClass(MCInst &Inst, unsigned RegNo,
144+
uint64_t Address,
145+
const void *Decoder);
146+
static DecodeStatus DecodePPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo,
147+
uint64_t Address,
148+
const void *Decoder);
143149

144150
static DecodeStatus DecodeFixedPointScaleImm32(MCInst &Inst, unsigned Imm,
145151
uint64_t Address,
@@ -707,6 +713,29 @@ DecodePPR_p8to15RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
707713
return DecodePPRRegisterClass(Inst, RegNo + 8, Addr, Decoder);
708714
}
709715

716+
/// Decode an encoded register number into a PPR2 (predicate pair) register
/// operand.
///
/// \param Inst    Instruction that receives the decoded register operand.
/// \param RegNo   Encoded register number; values above 15 are rejected.
/// \param Address Not used by this decoder.
/// \param Decoder Not used by this decoder.
/// \returns Fail when \p RegNo is out of range, Success otherwise.
static DecodeStatus DecodePPR2RegisterClass(MCInst &Inst, unsigned RegNo,
717+
uint64_t Address,
718+
const void *Decoder) {
719+
if (RegNo > 15)
720+
return Fail;
721+
722+
// RegNo is used directly as the index into the PPR2 register class.
unsigned Register =
723+
AArch64MCRegisterClasses[AArch64::PPR2RegClassID].getRegister(RegNo);
724+
Inst.addOperand(MCOperand::createReg(Register));
725+
return Success;
726+
}
727+
728+
/// Decode a register field that stores a predicate-pair index divided by two.
///
/// The encoded value is doubled before the lookup, so only even-numbered
/// entries of the PPR2 register class can be produced.
///
/// \param Inst    Instruction that receives the decoded register operand.
/// \param RegNo   Encoded field; rejected when RegNo * 2 exceeds 14.
/// \param Address Not used by this decoder.
/// \param Decoder Not used by this decoder.
/// \returns Fail when the doubled index is out of range, Success otherwise.
static DecodeStatus DecodePPR2Mul2RegisterClass(MCInst &Inst, unsigned RegNo,
729+
uint64_t Address,
730+
const void *Decoder) {
731+
if ((RegNo * 2) > 14)
732+
return Fail;
733+
// The lookup goes through the PPR2 class (not PPR2Mul2), indexed by the
// doubled field value.
unsigned Register =
734+
AArch64MCRegisterClasses[AArch64::PPR2RegClassID].getRegister(RegNo * 2);
735+
Inst.addOperand(MCOperand::createReg(Register));
736+
return Success;
737+
}
738+
710739
static DecodeStatus DecodeQQRegisterClass(MCInst &Inst, unsigned RegNo,
711740
uint64_t Addr,
712741
const MCDisassembler *Decoder) {

llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,23 @@ static unsigned getNextVectorRegister(unsigned Reg, unsigned Stride = 1) {
14151415
case AArch64::Z31:
14161416
Reg = AArch64::Z0;
14171417
break;
1418+
case AArch64::P0: Reg = AArch64::P1; break;
1419+
case AArch64::P1: Reg = AArch64::P2; break;
1420+
case AArch64::P2: Reg = AArch64::P3; break;
1421+
case AArch64::P3: Reg = AArch64::P4; break;
1422+
case AArch64::P4: Reg = AArch64::P5; break;
1423+
case AArch64::P5: Reg = AArch64::P6; break;
1424+
case AArch64::P6: Reg = AArch64::P7; break;
1425+
case AArch64::P7: Reg = AArch64::P8; break;
1426+
case AArch64::P8: Reg = AArch64::P9; break;
1427+
case AArch64::P9: Reg = AArch64::P10; break;
1428+
case AArch64::P10: Reg = AArch64::P11; break;
1429+
case AArch64::P11: Reg = AArch64::P12; break;
1430+
case AArch64::P12: Reg = AArch64::P13; break;
1431+
case AArch64::P13: Reg = AArch64::P14; break;
1432+
case AArch64::P14: Reg = AArch64::P15; break;
1433+
// Vector lists can wrap around.
1434+
case AArch64::P15: Reg = AArch64::P0; break;
14181435
}
14191436
}
14201437
return Reg;
@@ -1477,7 +1494,8 @@ void AArch64InstPrinter::printVectorList(const MCInst *MI, unsigned OpNum,
14771494
unsigned NumRegs = 1;
14781495
if (MRI.getRegClass(AArch64::DDRegClassID).contains(Reg) ||
14791496
MRI.getRegClass(AArch64::ZPR2RegClassID).contains(Reg) ||
1480-
MRI.getRegClass(AArch64::QQRegClassID).contains(Reg))
1497+
MRI.getRegClass(AArch64::QQRegClassID).contains(Reg) ||
1498+
MRI.getRegClass(AArch64::PPR2RegClassID).contains(Reg))
14811499
NumRegs = 2;
14821500
else if (MRI.getRegClass(AArch64::DDDRegClassID).contains(Reg) ||
14831501
MRI.getRegClass(AArch64::ZPR3RegClassID).contains(Reg) ||
@@ -1495,6 +1513,8 @@ void AArch64InstPrinter::printVectorList(const MCInst *MI, unsigned OpNum,
14951513
Reg = FirstReg;
14961514
else if (unsigned FirstReg = MRI.getSubReg(Reg, AArch64::zsub0))
14971515
Reg = FirstReg;
1516+
else if (unsigned FirstReg = MRI.getSubReg(Reg, AArch64::psub0))
1517+
Reg = FirstReg;
14981518

14991519
// If it's a D-reg, we need to promote it to the equivalent Q-reg before
15001520
// printing (otherwise getRegisterName fails).
@@ -1504,7 +1524,9 @@ void AArch64InstPrinter::printVectorList(const MCInst *MI, unsigned OpNum,
15041524
Reg = MRI.getMatchingSuperReg(Reg, AArch64::dsub, &FPR128RC);
15051525
}
15061526

1507-
if (MRI.getRegClass(AArch64::ZPRRegClassID).contains(Reg) && NumRegs > 1 &&
1527+
if ((MRI.getRegClass(AArch64::ZPRRegClassID).contains(Reg) ||
1528+
MRI.getRegClass(AArch64::PPRRegClassID).contains(Reg)) &&
1529+
NumRegs > 1 &&
15081530
// Do not print the range when the last register is lower than the first.
15091531
// Because it is a wrap-around register.
15101532
Reg < getNextVectorRegister(Reg, NumRegs - 1)) {
@@ -1520,7 +1542,8 @@ void AArch64InstPrinter::printVectorList(const MCInst *MI, unsigned OpNum,
15201542
} else {
15211543
for (unsigned i = 0; i < NumRegs; ++i, Reg = getNextVectorRegister(Reg)) {
15221544
// wrap-around sve register
1523-
if (MRI.getRegClass(AArch64::ZPRRegClassID).contains(Reg))
1545+
if (MRI.getRegClass(AArch64::ZPRRegClassID).contains(Reg) ||
1546+
MRI.getRegClass(AArch64::PPRRegClassID).contains(Reg))
15241547
printRegName(O, Reg);
15251548
else
15261549
printRegName(O, Reg, AArch64::vreg);

0 commit comments

Comments
 (0)