[llvm][AArch64][Assembly]: Add SME_F8F16 and SME_F8F32 Ass/Disass. #70640

Merged
1 commit merged on Nov 3, 2023

Conversation

hassnaaHamdi
Member

@hassnaaHamdi hassnaaHamdi commented Oct 30, 2023

This patch adds the feature flags for SME_F8F16 and SME_F8F32,
and the assembly/disassembly for the following SME2 instructions:

  • SME:
    • FMLAL, FMLALL
    • FVDOT, FVDOTT
    • FVDOTB
    • FMOPA

This follows the Arm documentation at:
https://developer.arm.com/documentation/ddi0602/2023-09
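
For illustration, the new instructions can be exercised with llvm-mc once the corresponding feature is enabled; because sme-f8f16 and sme-f8f32 imply +sme2 and +fp8 (see the feature definitions in the diff), no additional -mattr entries are needed. The operands below are illustrative only; the tests added under llvm/test/MC/AArch64/FP8_SME2/ are the authoritative reference for the accepted syntax.

  // RUN: llvm-mc -triple=aarch64 -mattr=+sme-f8f16 -show-encoding %s
  fmlal   za.h[w8, 0:1], z0.b, z0.b[0]            // FP8 to FP16 multiply-add long (indexed)
  fdot    za.h[w8, 0, vgx2], {z0.b, z1.b}, z0.b   // FP8 two-way dot product (multiple and single vector)
  fmopa   za0.h, p0/m, p1/m, z0.b, z1.b           // FP8 to FP16 outer product and accumulate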

@llvmbot added the backend:AArch64 and mc (Machine (object) code) labels on Oct 30, 2023
@llvmbot
Member

llvmbot commented Oct 30, 2023

@llvm/pr-subscribers-mc

@llvm/pr-subscribers-backend-aarch64

Author: None (hassnaaHamdi)

Changes

This patch adds the feature flags for SME_F8F16 and SME_F8F32,
and the assembly/disassembly for the following SME2 instructions:

  • SME:
    • FMLAL, FMLALL
    • FVDOT, FVDOTT
    • FVDOTB
    • FMOPA

Patch is 85.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70640.diff

18 Files Affected:

  • (modified) llvm/include/llvm/TargetParser/AArch64TargetParser.h (+4)
  • (modified) llvm/include/llvm/TargetParser/SubtargetFeature.h (+1-1)
  • (modified) llvm/lib/Target/AArch64/AArch64.td (+6)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.td (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+54)
  • (modified) llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp (+11-6)
  • (modified) llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+155-5)
  • (added) llvm/test/MC/AArch64/FP8_SME2/dot-diagnostics.s (+215)
  • (added) llvm/test/MC/AArch64/FP8_SME2/dot.s (+361)
  • (added) llvm/test/MC/AArch64/FP8_SME2/mla-diagnostics.s (+195)
  • (added) llvm/test/MC/AArch64/FP8_SME2/mla.s (+361)
  • (added) llvm/test/MC/AArch64/FP8_SME2/mopa-diagnostics.s (+46)
  • (added) llvm/test/MC/AArch64/FP8_SME2/mopa.s (+39)
  • (modified) llvm/test/MC/AArch64/SME2/fmlal-diagnostics.s (+2-2)
  • (modified) llvm/test/MC/AArch64/SME2/fvdot-diagnostics.s (+2-2)
  • (modified) llvm/unittests/TargetParser/TargetParserTest.cpp (+6-1)
diff --git a/llvm/include/llvm/TargetParser/AArch64TargetParser.h b/llvm/include/llvm/TargetParser/AArch64TargetParser.h
index 232b3d6a6dbb1c4..6d15b6ecad35a46 100644
--- a/llvm/include/llvm/TargetParser/AArch64TargetParser.h
+++ b/llvm/include/llvm/TargetParser/AArch64TargetParser.h
@@ -162,6 +162,8 @@ enum ArchExtKind : unsigned {
   AEK_FPMR =          58, // FEAT_FPMR
   AEK_FP8 =           59, // FEAT_FP8
   AEK_FAMINMAX =      60, // FEAT_FAMINMAX
+  AEK_SMEF8F16 =      61, // FEAT_SME_F8F16
+  AEK_SMEF8F32 =      62, // FEAT_SME_F8F32
   AEK_NUM_EXTENSIONS
 };
 using ExtensionBitset = Bitset<AEK_NUM_EXTENSIONS>;
@@ -273,6 +275,8 @@ inline constexpr ExtensionInfo Extensions[] = {
     {"fpmr", AArch64::AEK_FPMR, "+fpmr", "-fpmr", FEAT_INIT, "", 0},
     {"fp8", AArch64::AEK_FP8, "+fp8", "-fp8", FEAT_INIT, "+fpmr", 0},
     {"faminmax", AArch64::AEK_FAMINMAX, "+faminmax", "-faminmax", FEAT_INIT, "", 0},
+    {"sme-f8f16", AArch64::AEK_SMEF8F16, "+sme-f8f16", "-sme-f8f16", FEAT_INIT, "+sme2,+fp8", 0},
+    {"sme-f8f32", AArch64::AEK_SMEF8F32, "+sme-f8f32", "-sme-f8f32", FEAT_INIT, "+sme2,+fp8", 0},
     // Special cases
     {"none", AArch64::AEK_NONE, {}, {}, FEAT_INIT, "", ExtensionInfo::MaxFMVPriority},
 };
diff --git a/llvm/include/llvm/TargetParser/SubtargetFeature.h b/llvm/include/llvm/TargetParser/SubtargetFeature.h
index e4dddfb78effbcd..2e1f00dad2df365 100644
--- a/llvm/include/llvm/TargetParser/SubtargetFeature.h
+++ b/llvm/include/llvm/TargetParser/SubtargetFeature.h
@@ -31,7 +31,7 @@ namespace llvm {
 class raw_ostream;
 class Triple;
 
-const unsigned MAX_SUBTARGET_WORDS = 4;
+const unsigned MAX_SUBTARGET_WORDS = 5;
 const unsigned MAX_SUBTARGET_FEATURES = MAX_SUBTARGET_WORDS * 64;
 
 /// Container class for subtarget features.
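
The bump of MAX_SUBTARGET_WORDS from 4 to 5 (256 to 320 feature bits) is presumably needed because, with the two new extensions, the total number of AArch64 subtarget features defined in TableGen no longer fits in the previous FeatureBitset capacity.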
diff --git a/llvm/lib/Target/AArch64/AArch64.td b/llvm/lib/Target/AArch64/AArch64.td
index 8fd9358c9f9c7a0..805b7f9a9ad88ab 100644
--- a/llvm/lib/Target/AArch64/AArch64.td
+++ b/llvm/lib/Target/AArch64/AArch64.td
@@ -517,6 +517,12 @@ def FeatureSME2p1 : SubtargetFeature<"sme2p1", "HasSME2p1", "true",
 def FeatureFAMINMAX: SubtargetFeature<"faminmax", "HasFAMINMAX", "true",
    "Enable FAMIN and FAMAX instructions (FEAT_FAMINMAX)">;
 
+def FeatureSMEF8F16 : SubtargetFeature<"sme-f8f16", "HasSMEF8F16", "true",
+  "Enable Scalable Matrix Extension (SME) F8F16 instructions(FEAT_SME_F8F16)", [FeatureSME2, FeatureFP8]>;
+
+def FeatureSMEF8F32 : SubtargetFeature<"sme-f8f32", "HasSMEF8F32", "true",
+  "Enable Scalable Matrix Extension (SME) F8F32 instructions (FEAT_SME_F8F32)", [FeatureSME2, FeatureFP8]>;
+
 def FeatureAppleA7SysReg  : SubtargetFeature<"apple-a7-sysreg", "HasAppleA7SysReg", "true",
   "Apple A7 (the CPU formerly known as Cyclone)">;
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 97c5e89a29a9728..ddb856def3d167e 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -166,6 +166,10 @@ def HasFP8           : Predicate<"Subtarget->hasFP8()">,
                                  AssemblerPredicateWithAll<(all_of FeatureFP8), "fp8">;
 def HasFAMINMAX      : Predicate<"Subtarget->hasFAMINMAX()">,
                                  AssemblerPredicateWithAll<(all_of FeatureFAMINMAX), "faminmax">;
+def HasSMEF8F16     : Predicate<"Subtarget->hasSMEF8F16()">,
+                                 AssemblerPredicateWithAll<(all_of FeatureSMEF8F16), "sme-f8f16">;
+def HasSMEF8F32     : Predicate<"Subtarget->hasSMEF8F32()">,
+                                 AssemblerPredicateWithAll<(all_of FeatureSMEF8F32), "sme-f8f32">;
 
 // A subset of SVE(2) instructions are legal in Streaming SVE execution mode,
 // they should be enabled if either has been specified.
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index eb26591908fd79c..57ad51641a2f41e 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -1252,6 +1252,10 @@ class ZPRVectorListMul<int ElementWidth, int NumRegs> : ZPRVectorList<ElementWid
 
 let EncoderMethod = "EncodeRegAsMultipleOf<2>",
     DecoderMethod = "DecodeZPR2Mul2RegisterClass" in {
+  def ZZ_mul_r : RegisterOperand<ZPR2Mul2, "printTypedVectorList<0,0>"> {
+    let ParserMatchClass = ZPRVectorListMul<0, 2>;
+  }
+
   def ZZ_b_mul_r : RegisterOperand<ZPR2Mul2, "printTypedVectorList<0,'b'>"> {
     let ParserMatchClass = ZPRVectorListMul<8, 2>;
   }
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index f55b84b02f85162..1c77d15ba17b232 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -885,3 +885,57 @@ defm FAMIN_2Z2Z : sme2_fp_sve_destructive_vector_vg2_multi<"famin", 0b0010101>;
 defm FAMAX_4Z4Z : sme2_fp_sve_destructive_vector_vg4_multi<"famax", 0b0010100>;
 defm FAMIN_4Z4Z : sme2_fp_sve_destructive_vector_vg4_multi<"famin", 0b0010101>;
 } //[HasSME2, HasFAMINMAX]
+
+let Predicates = [HasSME2, HasSMEF8F16] in {
+defm FVDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_16b<"fvdot", 0b11, 0b110, ZZ_b_mul_r, ZPR4b8>;
+
+defm FDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_16b<"fdot",    0b11, 0b010, ZZ_b_mul_r, ZPR4b8>;
+defm FDOT_VG4_M4ZZI_BtoH : sme2p1_multi_vec_array_vg4_index_16b<"fdot",    0b100, ZZZZ_b_mul_r, ZPR4b8>;
+defm FDOT_VG2_M2ZZ_BtoH  :  sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>;
+defm FDOT_VG4_M4ZZ_BtoH  :  sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>;
+// TODO: Replace nxv16i8 by nxv16f8
+defm FDOT_VG2_M2Z2Z_BtoH : sme2_dot_mla_add_sub_array_vg2_multi<"fdot",    0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
+defm FDOT_VG4_M4Z4Z_BtoH : sme2_dot_mla_add_sub_array_vg4_multi<"fdot",    0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+
+def  FMLAL_MZZI_BtoH      : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>;
+defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>;
+defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>;
+def  FMLAL_VG2_MZZ_BtoH   : sme2_mla_long_array_single_16b<"fmlal">;
+// TODO: Replace nxv16i8 by nxv16f8
+defm FMLAL_VG2_M2ZZ_BtoH  : sme2_fp_mla_long_array_vg2_single<"fmlal",  0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>;
+defm FMLAL_VG4_M4ZZ_BtoH  :  sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>;
+defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal",   0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
+defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal",   0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+
+defm FMOPA_MPPZZ_BtoH : sme2p1_fmop_tile_fp16<"fmopa", 0b1, 0b0, 0b01, ZPR8>;
+
+} //[HasSME2, HasSMEF8F16]
+
+let Predicates = [HasSME2, HasSMEF8F32] in {
+// TODO : Replace nxv16i8 by nxv16f8
+defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
+defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
+defm FDOT_VG2_M2ZZ_BtoS  : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>;
+defm FDOT_VG4_M4ZZ_BtoS  : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>;
+// TODO : Replace nxv16i8 by nxv16f8
+defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot",   0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
+defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot",   0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+
+def FVDOTB_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdotb", 0b0>;
+def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>;
+
+defm FMLALL_MZZI_BtoS      : sme2_mla_ll_array_index_32b<"fmlall",     0b01, 0b000, null_frag>;
+defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>;
+defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>;
+// TODO: Replace nxv16i8 by nxv16f8
+defm FMLALL_MZZ_BtoS       : sme2_mla_ll_array_single<"fmlall",      0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>;
+defm FMLALL_VG2_M2ZZ_BtoS  : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>;
+defm FMLALL_VG4_M4ZZ_BtoS  : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>;
+defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall",   0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
+defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall",   0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
+
+
+defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>;
+
+} //[HasSME2, HasSMEF8F32]
+
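
To make the new mnemonics concrete, here is a hedged sketch of the FVDOT/FVDOTB/FVDOTT forms defined above (operands are illustrative; the added FP8_SME2 tests are authoritative):

  // RUN: llvm-mc -triple=aarch64 -mattr=+sme-f8f16,+sme-f8f32 -show-encoding %s
  fvdot   za.h[w8, 0, vgx2], {z0.b, z1.b}, z0.b[0]   // FP8 two-way vertical dot product to FP16
  fvdotb  za.s[w8, 0, vgx4], {z0.b, z1.b}, z0.b[0]   // FP8 four-way vertical dot product to FP32 (bottom)
  fvdott  za.s[w8, 0, vgx4], {z0.b, z1.b}, z0.b[0]   // FP8 four-way vertical dot product to FP32 (top)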
diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
index 35abe3563eb81ab..8a49b6c4e3439b5 100644
--- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
+++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
@@ -3641,6 +3641,8 @@ static const struct Extension {
     {"fpmr", {AArch64::FeatureFPMR}},
     {"fp8", {AArch64::FeatureFP8}},
     {"faminmax", {AArch64::FeatureFAMINMAX}},
+    {"sme-f8f16", {AArch64::FeatureSMEF8F16}},
+    {"sme-f8f32", {AArch64::FeatureSMEF8F32}},
 };
 
 static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) {
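
These table entries also make the new names available to the .arch_extension directive (the same mechanism used by the neighbouring entries), so an assembly file can enable them without -mattr. A minimal sketch with an illustrative instruction:

  .arch_extension sme-f8f32
  fmlall  za.s[w8, 0:3], z0.b, z0.b[0]   // FP8 to FP32 multiply-add long long (indexed)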
@@ -4536,7 +4538,7 @@ ParseStatus AArch64AsmParser::tryParseZTOperand(OperandVector &Operands) {
 
   Operands.push_back(AArch64Operand::CreateReg(
       RegNum, RegKind::LookupTable, StartLoc, getLoc(), getContext()));
-  Lex(); // Eat identifier token.
+  Lex(); // eat register
 
   // Check if register is followed by an index
   if (parseOptionalToken(AsmToken::LBrac)) {
@@ -4546,18 +4548,21 @@ ParseStatus AArch64AsmParser::tryParseZTOperand(OperandVector &Operands) {
     if (getParser().parseExpression(ImmVal))
       return ParseStatus::NoMatch;
     const MCConstantExpr *MCE = dyn_cast<MCConstantExpr>(ImmVal);
+    Operands.push_back(AArch64Operand::CreateImm(
+        MCConstantExpr::create(MCE->getValue(), getContext()), StartLoc,
+        getLoc(), getContext()));
     if (!MCE)
       return TokError("immediate value expected for vector index");
+    if (getTok().is(AsmToken::Comma)) {
+      Lex(); // eat comma
+      if (parseOptionalMulOperand(Operands))
+        return MatchOperand_ParseFail;
+    }
     if (parseToken(AsmToken::RBrac, "']' expected"))
       return ParseStatus::Failure;
-
-    Operands.push_back(AArch64Operand::CreateImm(
-        MCConstantExpr::create(MCE->getValue(), getContext()), StartLoc,
-        getLoc(), getContext()));
     Operands.push_back(
         AArch64Operand::CreateToken("]", getLoc(), getContext()));
   }
-
   return ParseStatus::Success;
 }
 
diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
index 988c78699179f0c..c5de5b4de4aef3a 100644
--- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
+++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64InstPrinter.cpp
@@ -1740,6 +1740,10 @@ template <unsigned NumLanes, char LaneKind>
 void AArch64InstPrinter::printTypedVectorList(const MCInst *MI, unsigned OpNum,
                                               const MCSubtargetInfo &STI,
                                               raw_ostream &O) {
+  if (LaneKind == 0) {
+    printVectorList(MI, OpNum, STI, O, "");
+    return;
+  }
   std::string Suffix(".");
   if (NumLanes)
     Suffix += itostr(NumLanes) + LaneKind;
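
With a zero LaneKind the list is printed without an element-type suffix; a sketch of the intended difference (spacing follows the existing printVectorList output and is illustrative):

  // printTypedVectorList<0,'b'> prints { z0.b, z1.b }
  // printTypedVectorList<0,0>   prints { z0, z1 }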
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index d8b44c68fbdee10..c6faf223611d7f7 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -1922,6 +1922,17 @@ multiclass sme2_mla_long_array_single<string mnemonic, bits<2> op0, bits<2> op,
   def : SME2_ZA_TwoOp_Multi_Single_Pat<NAME # _HtoS, intrinsic, uimm3s2range, ZPR4b16, zpr_ty, tileslicerange3s2>;
 }
 
+class sme2_mla_long_array_single_16b<string mnemonic>
+    : sme2_mla_long_array<0b00, 0b00, MatrixOp16, uimm3s2range, ZPR8, ZPR4b8,  mnemonic> {
+    bits<4> Zm;
+    bits<5> Zn;
+    bits<3> imm;
+    let Inst{20}    = 0b1;
+    let Inst{19-16} = Zm;
+    let Inst{9-5}   = Zn;
+    let Inst{2-0}   = imm;
+}
+
 class sme2_mla_long_array_vg24_single<bits<2> op0, bit vg4, bits<2> op, bit o2,
                                       MatrixOperand matrix_ty, RegisterOperand multi_vector_ty,
                                       ZPRRegOp zpr_ty, string mnemonic, string vg_acronym>
@@ -1937,7 +1948,6 @@ class sme2_mla_long_array_vg24_single<bits<2> op0, bit vg4, bits<2> op, bit o2,
   let Inst{1-0}   = imm;
 }
 
-	
 multiclass sme2_fp_mla_long_array_vg2_single<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
                                              RegisterOperand multi_vector_ty, ZPRRegOp vector_ty,
                                              ValueType zpr_ty, SDPatternOperator intrinsic> {
@@ -1971,7 +1981,8 @@ multiclass sme2_fp_mla_long_array_vg4_single<string mnemonic, bits<3> op, Matrix
                                              RegisterOperand multi_vector_ty, ZPRRegOp vector_ty,
                                              ValueType zpr_ty, SDPatternOperator intrinsic> {
   def NAME : sme2_mla_long_array_vg24_single<0b00, 0b1, op{2-1}, op{0}, matrix_ty, multi_vector_ty, 
-                                             vector_ty, mnemonic, "vgx4">, SMEPseudo2Instr<NAME, 1>;
+                                             vector_ty, mnemonic, "vgx4">,
+                                             SMEPseudo2Instr<NAME, 1>;
 
   def _PSEUDO : sme2_za_array_2op_multi_single_pseudo<NAME, uimm2s2range, multi_vector_ty, vector_ty,
                                                       SMEMatrixArray>;
@@ -2014,7 +2025,6 @@ class sme2_mla_long_array_vg2_multi<string mnemonic, bits<2> op0, bits<3> op,
 multiclass sme2_fp_mla_long_array_vg2_multi<string mnemonic, bits<3> op, MatrixOperand matrix_ty,
                                             RegisterOperand multi_vector_ty,
                                             ValueType zpr_ty, SDPatternOperator intrinsic> {
-
   def NAME : sme2_mla_long_array_vg2_multi<mnemonic, 0b10, op, matrix_ty, multi_vector_ty>,
                                            SMEPseudo2Instr<NAME, 1>;
 
@@ -2390,7 +2400,6 @@ multiclass sme2_zip_vector_vg2<string mnemonic, bit op> {
 
 //===----------------------------------------------------------------------===//
 // SME2 Dot Products and MLA
-
 class sme2_multi_vec_array_vg2_index<bits<2> sz, bits<6> op, MatrixOperand matrix_ty,
                                      RegisterOperand multi_vector_ty,
                                      ZPRRegOp vector_ty, Operand index_ty,
@@ -2428,7 +2437,6 @@ multiclass sme2_multi_vec_array_vg2_index_32b<string mnemonic, bits<2> sz, bits<
     bits<2> i;
     let Inst{11-10} = i;
   }
-
   def _PSEUDO : sme2_za_array_2op_multi_index_pseudo<NAME, sme_elm_idx0_7, multi_vector_ty, vector_ty, VectorIndexS32b_timm, SMEMatrixArray>;
 
   def : SME2_ZA_TwoOp_VG2_Multi_Index_Pat<NAME, intrinsic, sme_elm_idx0_7, vector_ty, vt, VectorIndexS32b_timm, tileslice16>;
@@ -2439,6 +2447,7 @@ multiclass sme2_multi_vec_array_vg2_index_32b<string mnemonic, bits<2> sz, bits<
 }
 
 // SME2.1 multi-vec ternary indexed two registers 16-bit
+// SME2 multi-vec indexed FP8 two-way dot product to FP16 two registers
 multiclass sme2p1_multi_vec_array_vg2_index_16b<string mnemonic, bits<2> sz, bits<3> op,
                                                 RegisterOperand multi_vector_ty, ZPRRegOp zpr_ty> {
   def NAME : sme2_multi_vec_array_vg2_index<sz, {op{2},?,?,op{1-0},?}, MatrixOp16,
@@ -2448,11 +2457,24 @@ multiclass sme2p1_multi_vec_array_vg2_index_16b<string mnemonic, bits<2> sz, bit
     let Inst{11-10} = i{2-1};
     let Inst{3}     = i{0};
   }
+
   def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm3], $Zn, $Zm$i",
         (!cast<Instruction>(NAME) MatrixOp16:$ZAda,  MatrixIndexGPR32Op8_11:$Rv, sme_elm_idx0_7:$imm3,
         multi_vector_ty:$Zn, zpr_ty:$Zm, VectorIndexH:$i), 0>;
 }
 
+// SME2 multi-vec indexed FP8 two-way vertical dot product to single precision
+// two registers
+class sme2_fp8_multi_vec_array_vg4_index<string mnemonic, bit T>
+   : sme2_multi_vec_array_vg2_index<0b11, {0b01,?,0b0, T,?}, MatrixOp32,
+                                     ZZ_b_mul_r, ZPR4b8, VectorIndexS, mnemonic> {
+
+  bits<2> i;
+  let Inst{10} = i{1};
+  let Inst{3}  = i{0};
+  let AsmString = !strconcat(mnemonic, "{\t$ZAda[$Rv, $imm3, vgx4], $Zn, $Zm$i}");
+}
+
 // SME2 multi-vec ternary indexed two registers 64-bit
 
 class sme2_multi_vec_array_vg2_index_64b<bits<2> op,
@@ -2608,7 +2630,83 @@ multiclass sme2_multi_vec_array_vg4_index_64b<string mnemonic, bits<3> op,
         (!cast<Instruction>(NAME) MatrixOp64:$ZAda,  MatrixIndexGPR32Op8_11:$Rv, sme_elm_idx0_7:$imm3,
         multi_vector_ty:$Zn, vector_ty:$Zm, VectorIndexD32b_timm:$i1), 0>;
 }
+
+// FMLAL (multiple and indexed vector, FP8 to FP16)
+class sme2_multi_vec_array_vg24_index_16b<bits<2> sz, bit vg4, bits<3> op,
+                                          RegisterOperand multi_vector_ty, string mnemonic>
+    : I<(outs MatrixOp16:$ZAda),
+        (ins MatrixOp16:$_ZAda, MatrixIndexGPR32Op8_11:$Rv, uimm2s2range:$imm2,
+         multi_vector_ty:$Zn, ZPR4b8:$Zm, VectorIndexB:$i),
+         mnemonic, "\t$ZAda[$Rv, $imm2, " # !if(vg4, "vgx4", "vgx2") # "], $Zn, $Zm$i",
+        "", []>, Sched<[]> {
+  bits<4> Zm;
+  bits<2> Rv;
+  bits<4> i;
+  bits<2> imm2;
+  let Inst{31-24} = 0b11000001;
+  let Inst{23-22} = sz;
+  let Inst{21-20} = 0b01;
+  let Inst{19-16} = Zm;
+  let Inst{15}    = vg4;
+  let Inst{14-13} = Rv;
+  let Inst{12}    = op{2};
+  let Inst{11-10} = i{3-2};
+  let Inst{5-4}   = op{1-0};
+  let Inst{3-2}   = i{1-0};
+  let Inst{1-0}   = imm2;
+
+  let Constraints = "$ZAda = $_ZAda";
+}
+
+multiclass sme2_multi_vec_array_vg2_index_16b<string mnemonic, bits<2> sz, bits<3>op> {
+  def NAME : sme2_multi_vec_array_vg24_index_16b<sz, 0b0, op, ZZ_b_mul_r, mnemonic> {
+    bits<4> Zn;
+    let Inst{9-6} = Zn;
+ }
+ def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm2], $Zn, $Zm$i",
+                 (!cast<Instruction>(NAME) MatrixOp16:$ZAda,  MatrixIndexGPR32Op8_11:$Rv,
+                  uimm2s2range:$imm2, ZZ_b_mul_r:$Zn, ZPR4b8:$Zm, VectorIndexB:$i), 0>;
+}
+
+multiclass sme2_multi_vec_array_vg4_index_16b<string mnemonic, bits<2>sz, bits<3>op> {
+  def NAME: sme2_multi_vec_array_vg24_index_16b<sz, 0b1, op, ZZZZ_b_mul_r, mnemonic> {
+    bits<3> Zn;
+    let Inst{9-7} = Zn;
+    let Inst{6}   = 0b0;
+  }
+ def : InstAlias<mnemonic # "\t$ZAda[$Rv, $imm2], $Zn, $Zm$i",
+                 (!cast<Instruction>(NAME) MatrixOp16:$ZAda,  MatrixIndexGPR32Op8_11:$Rv,
+                  uimm2s2range:$imm2, ZZZZ_b_mul_r:$Zn, ZPR4b8:$Zm, VectorIndexB:$i), 0>;
+}
+
 //===----------------------------------------------------------------------===//
+// SME2 multi-vec indexed long long MLA one source 16-bit
+class sme2_mla_ll_array_index_16b<string mnemonic, bits<2> sz,bits<2> op>
+    : I<(outs MatrixOp16:$ZAda),
+        (ins MatrixOp16:$_ZAda, MatrixIndexGPR32Op8_11:$Rv, uimm3s2range:$imm3, ZPR8:$Zn, ZPR4b8:$Zm, VectorIndexB32b_timm:$i),
+        mnemonic, "\t$ZAda[$Rv, $imm3], $Zn, $Zm$i",
+        "", []>, Sched<[]> {
+  bits<4> Zm;
+  bits<2> Rv;
+  bits<4> i;
+  bits<5> Zn;
+  bits<3> imm3;
+  let Inst{31-24} = 0b11000001;
+  let Inst{23-22} = sz;
+  let Inst{21-20} = 0b00;
+  let Inst{19-16} = Zm;
+  let Inst{15}    = i{3};
+  let Inst{14-13} = Rv;
+  let Inst{12}    = op{1};
+  let Inst{11-10} = i{2-1};
+  let Inst{9-5}   = Zn;
+  let Inst{4}     = op{0};
+  let Inst{3}     = i{0};
+  let Inst{2-0}   = imm3;
+
+  let Constraints = "$ZAda = $_ZAda";
+}
+
 // SME2 multi-vec indexed long long MLA one source 32-bit
 class sme2_mla_ll_array_index_32b<string mnemonic, bits<2> sz, bits<3> op>
     : I<(outs MatrixOp32:$ZAda),
@@ -3059,6 +3157,25 @@ class sme2_movt_scalar_to_zt<string mnemonic, bits<7> opc>
   let Inst{4-0}   = Rt;
 }
 
+// SME2 move vec...
[truncated]

Contributor

@CarolineConcatto CarolineConcatto left a comment


Thank you Hassnaa for this patch. I have left some comments.

@@ -285,6 +287,8 @@ inline constexpr ExtensionInfo Extensions[] = {
{"ssve-fp8dot2", AArch64::AEK_SSVE_FP8DOT2, "+ssve-fp8dot2", "-ssve-fp8dot2", FEAT_INIT, "+sme2", 0},
{"fp8dot4", AArch64::AEK_FP8DOT4, "+fp8dot4", "-fp8dot4", FEAT_INIT, "", 0},
{"ssve-fp8dot4", AArch64::AEK_SSVE_FP8DOT4, "+ssve-fp8dot4", "-ssve-fp8dot4", FEAT_INIT, "+sme2", 0},
{"sme-f8f16", AArch64::AEK_SMEF8F16, "+sme-f8f16", "-sme-f8f16", FEAT_INIT, "+sme2,+fp8", 0},
Contributor

I think we can remove this from here; we can add it in the future if needed.
You can remove it for both sme-f8f16 and sme-f8f32.

Member Author

Why is it different for this one only? I mean, as long as the others exist, why not keep it?

@@ -535,6 +535,12 @@ def FeatureFP8DOT4: SubtargetFeature<"fp8dot4", "HasFP8DOT4", "true",
def FeatureSSVE_FP8DOT4 : SubtargetFeature<"ssve-fp8dot4", "HasSSVE_FP8DOT4", "true",
"Enable SVE2 fp8 4-way dot product instructions (FEAT_SSVE_FP8DOT4)", [FeatureSME2]>;

def FeatureSMEF8F16 : SubtargetFeature<"sme-f8f16", "HasSMEF8F16", "true",
"Enable Scalable Matrix Extension (SME) F8F16 instructions(FEAT_SME_F8F16)", [FeatureSME2, FeatureFP8]>;
Contributor

Remove this from here too. We can add it later if needed.

Member Author

Doesn't the "HasSMEF8F16" depend on this feature definition ?
And HasSMEFP8FP16 is required by some instances like FMOPA_MPPZZ_BtoH for example.

@@ -3059,6 +3158,25 @@ class sme2_movt_scalar_to_zt<string mnemonic, bits<7> opc>
let Inst{4-0} = Rt;
}

// SME2 move vector to lookup table
class sme2_movt_zt_to_zt<string mnemonic, bits<7> opc>
Contributor

Why is this class here? It is not used anywhere else.

@hassnaaHamdi force-pushed the main branch 2 times, most recently from 273cd13 to ab76045 on November 2, 2023 at 18:19
Contributor

@CarolineConcatto CarolineConcatto left a comment


Thank you Hassnaa for the work.
Before merging, remove +sme2 from the diagnostic tests.
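
For reference, a sketch of the kind of RUN line this refers to (the actual test lines may differ); because sme-f8f16 and sme-f8f32 already imply +sme2, the extra +sme2 is redundant:

  // before: RUN: not llvm-mc -triple=aarch64 -mattr=+sme2,+sme-f8f16 -show-encoding %s 2>&1 | FileCheck %s
  // after:  RUN: not llvm-mc -triple=aarch64 -mattr=+sme-f8f16 -show-encoding %s 2>&1 | FileCheck %s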

This patch adds the feature flags of SME_F8F16 and SME_F8F32,
 and the assembly/disassembly for the following instructions of SME2:
  * SME:
    - FMLAL
    - FMLALL
    - FVDOT
    - FVDOTB
    - FVDOTT
    - FMOPA

Change-Id: Ib5d5c675651633eac8bb382d674c0d7735e1ee5c