[mlir][ArmSME] Make use of backend function attributes for enabling ZA storage #71044

Merged: 7 commits on Nov 14, 2023

MacDue
Member

@MacDue MacDue commented Nov 2, 2023

Previously, we were inserting za.enable/disable intrinsics for functions with the "arm_za" attribute (at the MLIR level), rather than using the backend attributes. This was done to avoid a dependency on the SME ABI functions from compiler-rt (which have only recently been implemented).

Doing things this way had correctness issues. For example, calling a streaming-mode function from another streaming-mode function (both with ZA enabled) would lead to ZA being disabled after returning to the caller (where it should still be enabled). Fixing issues like this would require redoing, within MLIR, the ABI work already done in the backend.
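
The failure mode described above can be illustrated with a toy model. This is only a sketch: a single global flag stands in for the ZA enable state, and the function names (`calleeOldScheme`, `calleePreservingZA`, `zaSurvivesNestedCall`) are hypothetical and not part of MLIR or LLVM. The "preserving" variant models only the observable outcome (the caller's ZA state survives the call); it does not attempt to model the mechanism the backend actually uses.

```cpp
#include <cassert>

// Toy model: one flag stands in for "ZA is enabled". Not real SME state.
bool zaEnabled = false;

// Old scheme: an "arm_za" function enables ZA on entry and disables it
// before returning, regardless of the state it was called in.
void calleeOldScheme() {
  zaEnabled = true;  // models za.enable inserted at function entry
  // ... body uses ZA ...
  zaEnabled = false; // models za.disable inserted before return
}

// Illustrative alternative (an assumption, not the backend's algorithm):
// leave the caller's state as it was found.
void calleePreservingZA() {
  const bool callerState = zaEnabled;
  zaEnabled = true;
  // ... body uses ZA ...
  zaEnabled = callerState;
}

// Runs `callee` from a ZA-enabled caller and reports whether ZA is still
// enabled in the caller once the nested call returns.
bool zaSurvivesNestedCall(void (*callee)()) {
  assert(callee != nullptr);
  zaEnabled = true; // the caller enabled ZA on its own entry
  callee();
  const bool stillEnabled = zaEnabled;
  zaEnabled = false; // the caller's own exit
  return stillEnabled;
}
```

In this model, `zaSurvivesNestedCall(calleeOldScheme)` returns `false`, which is the bug: the caller finds ZA disabled after the nested call. `zaSurvivesNestedCall(calleePreservingZA)` returns `true`.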

Instead, this patch switches to use the "arm_new_za" (backend) attribute for enabling ZA for an MLIR function. For the integration tests, this requires some way of linking the SME ABI functions. This is done via the %arm_sme_abi_shlib lit substitution. By default, this expands to a stub implementation of the SME ABI functions, but this can be overridden by providing the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable (pointing it at an alternative implementation). For now, the ArmSME integration tests pass with just stubs, as we don't make use of nested ZA-enabled calls.
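
As a rough sketch of the attribute flow this patch sets up, the snippet below models which function attributes the `enable-arm-streaming` pass attaches for a given pair of options, and which LLVM IR function attribute each one is exported as. The enum names and attribute strings come from the diff (the `arm_streaming` mapping is inferred from the import-side code). The helpers `mlirFuncAttrs` and `llvmFnAttrFor` are hypothetical stand-ins that do not exist in MLIR, and the `enable_arm_streaming_ignore` early exit is omitted for brevity.

```cpp
#include <string>
#include <vector>

// Mirrors the enums this patch adds to Passes.h.
enum class ArmStreamingMode { Default = 0, Locally = 1 };
enum class ArmZaMode { Disabled = 0, New = 1 };

// Hypothetical helper: the MLIR function attributes the pass would set.
std::vector<std::string> mlirFuncAttrs(ArmStreamingMode streamingMode,
                                       ArmZaMode zaMode) {
  std::vector<std::string> attrs;
  attrs.push_back(streamingMode == ArmStreamingMode::Locally
                      ? "arm_locally_streaming"
                      : "arm_streaming");
  if (zaMode == ArmZaMode::New)
    attrs.push_back("arm_new_za");
  return attrs;
}

// Hypothetical helper: the LLVM IR function attribute each MLIR attribute
// is translated to. Returns an empty string for attributes not modelled here.
std::string llvmFnAttrFor(const std::string &mlirAttr) {
  if (mlirAttr == "arm_streaming")
    return "aarch64_pstate_sm_enabled";
  if (mlirAttr == "arm_locally_streaming")
    return "aarch64_pstate_sm_body";
  if (mlirAttr == "arm_new_za")
    return "aarch64_pstate_za_new";
  return "";
}
```

For example, `mlirFuncAttrs(ArmStreamingMode::Locally, ArmZaMode::New)` corresponds to the `-enable-arm-streaming="streaming-mode=locally za-mode=new"` invocation used by the integration tests.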

A future patch may add an option to compiler-rt to build the SME builtins into a standalone shared library to allow easily building/testing with the actual implementation.
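
The fallback rule for `%arm_sme_abi_shlib` described in this PR (stubs by default, overridden by a CMake cache variable) can be sketched as below. The real logic lives in `mlir/test/lit.cfg.py`, which falls in the truncated part of the diff, so this C++ model is an assumption based only on the description: an empty `ARM_SME_ABI_ROUTINES_SHLIB` means "use the stub library". The function names and any paths passed to them are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: pick the library that %arm_sme_abi_shlib expands to.
// `overridePath` models ARM_SME_ABI_ROUTINES_SHLIB (empty by default);
// `stubLibPath` models the path of the built stub library.
std::string resolveArmSmeAbiShlib(const std::string &overridePath,
                                  const std::string &stubLibPath) {
  return overridePath.empty() ? stubLibPath : overridePath;
}

// Hypothetical helper: build a comma-separated -shared-libs flag like the
// one used in the integration-test RUN lines.
std::string sharedLibsFlag(const std::vector<std::string> &libs) {
  std::string flag = "-shared-libs=";
  for (std::size_t i = 0; i < libs.size(); ++i) {
    if (i != 0)
      flag += ',';
    flag += libs[i];
  }
  return flag;
}
```

Keeping the override as a plain path means an alternative implementation of the SME ABI routines can be swapped in without touching the individual test files.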

@MacDue
Member Author

MacDue commented Nov 2, 2023

cc @c-rhodes, @banach-space

Keeping this as a draft for now as I'm a little unhappy with the current solution for the mlir_arm_sme_runtime library.

@llvmbot
Member

llvmbot commented Nov 6, 2023

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-execution-engine
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)


Patch is 31.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71044.diff

25 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (-3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+7-3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+14-6)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+12-13)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+1-54)
  • (added) mlir/lib/ExecutionEngine/ArmSMEStub.cpp (+48)
  • (modified) mlir/lib/ExecutionEngine/CMakeLists.txt (+5)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+6-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+3)
  • (modified) mlir/test/CMakeLists.txt (+6)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+3-3)
  • (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+10-12)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-2)
  • (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (-11)
  • (modified) mlir/test/lit.cfg.py (+18-4)
  • (modified) mlir/test/lit.site.cfg.py.in (+1)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e369ef203ad39d6..9f4ef24366b09db 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -131,7 +131,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
 def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
 
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index ab5c179f2dd7790..95b016e87921a67 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -24,15 +24,19 @@ namespace arm_sme {
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
 // manages it on entry/exit.
-enum class ArmStreaming { Default = 0, Locally = 1 };
+enum class ArmStreamingMode { Default = 0, Locally = 1 };
+
+// TODO: Add other ZA modes.
+// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
+enum class ArmZaMode { Disabled = 0, New = 1 };
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
 
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
-                             const bool enableZA = false);
+createEnableArmStreamingPass(const ArmStreamingMode = ArmStreamingMode::Default,
+                             const ArmZaMode = ArmZaMode::Disabled);
 
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3fa1b43eb9e67e0..2ea5c6947754e65 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -22,19 +22,27 @@ def EnableArmStreaming
   }];
   let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
   let options = [
-    Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
-          /*default=*/"mlir::arm_sme::ArmStreaming::Default",
+    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
+          /*default=*/"mlir::arm_sme::ArmStreamingMode::Default",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Default, "default",
 						   "Streaming mode is part of the function interface "
 						   "(ABI), caller manages PSTATE.SM on entry/exit."),
-                clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Locally, "locally",
 						   "Streaming mode is internal to the function, callee "
 						   "manages PSTATE.SM on entry/exit.")
           )}]>,
-    Option<"enableZA", "enable-za", "bool", /*default=*/"false",
-           "Enable ZA storage array.">,
+    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
+           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
+           "Select how ZA-storage is managed at the function-level.",
+           [{::llvm::cl::values(
+                clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
+					 	   "ZA storage is disabled."),
+                clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
+					 	   "The function has ZA state. The ZA state is created on entry "
+               "and destroyed on exit.")
+           )}]>
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 638c31b39682ea6..dfc0588e92e44ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1415,6 +1415,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
     OptionalAttr<UnitAttr>:$arm_streaming,
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
+    OptionalAttr<UnitAttr>:$arm_new_za,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 1d3a090e861013b..1b59b6d907235b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -51,26 +51,26 @@ using namespace mlir::arm_sme;
 
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
-static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmNewZAAttr[] = "arm_new_za";
 static constexpr char kEnableArmStreamingIgnoreAttr[] =
     "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
-    this->mode = mode;
-    this->enableZA = enableZA;
+  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+    this->streamingMode = streamingMode;
+    this->zaMode = zaMode;
   }
   void runOnOperation() override {
     if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
       return;
     StringRef attr;
-    switch (mode) {
-    case ArmStreaming::Default:
+    switch (streamingMode) {
+    case ArmStreamingMode::Default:
       attr = kArmStreamingAttr;
       break;
-    case ArmStreaming::Locally:
+    case ArmStreamingMode::Locally:
       attr = kArmLocallyStreamingAttr;
       break;
     }
@@ -80,14 +80,13 @@ struct EnableArmStreamingPass
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (enableZA)
-      getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
+    if (zaMode == ArmZaMode::New)
+      getOperation()->setAttr(kArmNewZAAttr, UnitAttr::get(&getContext()));
   }
 };
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
-                                            const bool enableZA) {
-  return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
+std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
+    const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
+  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d1a54658a595bf3..6078b3f2c5e4708 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -21,33 +21,6 @@ using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
-/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
-/// ops to enable the ZA storage array.
-struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::FuncOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(&op.front());
-    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
-
-/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
-/// disable the ZA storage array.
-struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::ReturnOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(op);
-    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
 
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
@@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
-
-  // Mark 'func.func' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and the first op in the
-  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
-  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
-    if (funcOp.isDeclaration())
-      return true;
-    auto firstOp = funcOp.getBody().front().begin();
-    return !funcOp->hasAttr("arm_za") ||
-           isa<arm_sme::aarch64_sme_za_enable>(firstOp);
-  });
-
-  // Mark 'func.return' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and there's a preceding
-  //      'arm_sme::aarch64_sme_za_disable' intrinsic.
-  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
-    bool hasDisableZA = false;
-    auto funcOp = returnOp->getParentOp();
-    funcOp->walk<WalkOrder::PreOrder>(
-        [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
-    return !funcOp->hasAttr("arm_za") || hasDisableZA;
-  });
 }
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
new file mode 100644
index 000000000000000..f9f64ad5e5ac81c
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
@@ -0,0 +1,48 @@
+//===- ArmSMEStub.cpp - ArmSME ABI routine stubs --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/Compiler.h"
+#include <cstdint>
+#include <iostream>
+
+// The actual implementation of these routines is in:
+// compiler-rt/lib/builtins/aarch64/sme-abi.S. These stubs allow the current
+// ArmSME tests to run without depending on compiler-rt. This works as we don't
+// rely on nested ZA-enabled calls at the moment. The use of these stubs can be
+// overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable to
+// a path to an alternate implementation.
+
+extern "C" {
+
+bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
+  // The ArmSME tests are run within an emulator so we assume SME is available.
+  return true;
+}
+
+struct sme_state {
+  int64_t x0;
+  int64_t x1;
+};
+
+sme_state LLVM_ATTRIBUTE_WEAK __arm_sme_state() {
+  std::cerr << "[warning] __arm_sme_state() stubbed!\n";
+  return sme_state{};
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_restore() {
+  std::cerr << "[warning] __arm_tpidr2_restore() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_save() {
+  std::cerr << "[warning] __arm_tpidr2_save() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_za_disable() {
+  std::cerr << "[warning] __arm_za_disable() stubbed!\n";
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdc797763ae3a41..aa8cb9728e8cdd0 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  ArmSMEStub.cpp
   AsyncRuntime.cpp
   CRunnerUtils.cpp
   CudaRuntimeWrappers.cpp
@@ -177,6 +178,10 @@ if(LLVM_ENABLE_PIC)
     target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
   endif()
 
+  add_mlir_library(mlir_arm_sme_abi_stubs
+    SHARED
+    ArmSMEStub.cpp)
+
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
     # custom error message.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index e3562049cd81c76..b4c56f995234cb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1583,7 +1583,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
     // explicit attribute.
     // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
+        attrName == "aarch64_pstate_sm_body" ||
+        attrName == "aarch64_pstate_za_new" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1623,6 +1624,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+
+  if (func->hasFnAttribute("aarch64_pstate_za_new"))
+    funcOp.setArmNewZa(true);
+
   llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
   if (attr.isValid()) {
     MLIRContext *context = funcOp.getContext();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 7312388bc9b4dd2..e6247e12ecb38ac 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (func.getArmNewZa())
+    llvmFunc->addFnAttr("aarch64_pstate_za_new");
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index d81f3c4b1e20c5a..c7b7debffc56ab9 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -28,6 +28,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
       "If arch-specific Arm integration tests run emulated, find Arm native utility libraries in this directory.")
   set(MLIR_GPU_COMPILATION_TEST_FORMAT "fatbin" CACHE STRING
       "The GPU compilation format used by the tests.")
+  set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
+      "Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
   option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
   option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
@@ -139,6 +141,10 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
+if (MLIR_RUN_ARM_SME_TESTS)
+  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_abi_stubs)
+endif()
+
 list(APPEND MLIR_TEST_DEPENDS MLIRUnitTests)
 
 if(LLVM_BUILD_EXAMPLES)
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index e7bbe8c0047687d..2ec6f4090dff0c2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
 // CHECK-LOCALLY-LABEL: @arm_streaming
 // CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
 // CHECK-ENABLE-ZA-LABEL: @arm_streaming
-// CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
+// CHECK-ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
 func.func @arm_streaming() { return }
 
 // CHECK-LABEL: @not_arm_streaming
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index d415b19f6fa94cf..8631721ef61bc77 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,18 +1,16 @@
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
 
-// CHECK-LABEL: @arm_za
-func.func @arm_za() {
-  // ENABLE-ZA: arm_sme.intr.za.enable
-  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
-  // ENABLE-ZA-NEXT: return
-  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
-  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
-  return
-}
+// ENABLE-ZA-LABEL: @arm_new_za
+// ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
+// DISABLE-ZA-LABEL: @arm_new_za
+// DISABLE-ZA-NOT: arm_new_za
+// DISABLE-ZA-SAME: attributes {arm_streaming}
+// NO-ARM-STREAMING-LABEL: @arm_new_za
+// NO-ARM-STREAMING-NOT: arm_new_za
+// NO-ARM-STREAMING-NOT: arm_streaming
+func.func @arm_new_za() { return }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 131cbc05a9857e0..1d9f3977389c850 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -3,14 +3,14 @@
 // RUN:   -test-transform-dialect-erase-schedule \
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
 // RUN: FileCheck %s
 
 func.func @entry() {
diff --git a/mlir/...
[truncated]

@MacDue MacDue requested a review from dcaballe November 6, 2023 12:25
Collaborator

@c-rhodes c-rhodes left a comment

I've left some minor nits, but otherwise LGTM, cheers!

if config.mlir_run_arm_sme_tests:
    config.substitutions.append(
        (
            "%arm_sme_abi_shlib",
Collaborator

nit: %arm_sme_runtime or %arm_sme_runtime_shlib?

Member Author

@MacDue MacDue Nov 8, 2023

I think "runtime" implies this is something more substantial, but this library should only ever contain a few SME ABI builtins.

Collaborator

fair enough, that's what these are tho, runtime support routines [1]:

Every platform that supports SME must provide the following runtime support routines:

[1] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines

Contributor

I also think that "runtime" can be interpreted as something quite substantial. Having said that, I agree with Cullen that we should strive for consistency with the spec. I suggest "arm_sme_support_routines_shlib".

In any case, I wouldn't worry about this particular name too much :) (we can always rename later)

Contributor

@banach-space banach-space left a comment

LGTM, thanks!

@MacDue
Member Author

MacDue commented Nov 10, 2023

@c-rhodes I had to update a few recent tests to avoid nested ZA-enabled calls, which only work when using the real ABI runtime. If the stubs are used, the backend-emitted checks end up using some invalid state. See: 84c9d7f

@c-rhodes
Collaborator

@c-rhodes I had to update a few recent tests to avoid nested ZA-enabled calls, which only work when using the real ABI runtime. If the stubs are used, the backend-emitted checks end up using some invalid state. See: 84c9d7f

Thanks for heads up. I left a couple of minor nits I spotted but otherwise still LGTM, cheers.

…A storage

Previously, we were inserting za.enable/disable intrinsics for functions
with the "arm_za" attribute (at the MLIR level), rather than using the
backend attributes. This was done to avoid a dependency on the SME ABI
functions from compiler-rt (which have only recently been implemented).

Doing things this way did have correctness issues, for example, calling
a streaming-mode function from another streaming-mode function (both
with ZA enabled) would lead to ZA being disabled after returning to the
caller (where it should still be enabled). Fixing issues like this would
require re-doing the ABI work already done in the backend within MLIR.

Instead, this patch switches to use the "arm_new_za" (backend) attribute
for enabling ZA for an MLIR function. For the integration tests, this
requires some way of linking the SME ABI functions. This has been done
by adding a mlir_arm_sme_runtime library, which includes the
implementation from compiler-rt, which can then be linked via the
`-shared-libs` flag.

To build the mlir_arm_sme_runtime the target has to be AArch64, and
the host compiler must be able to assemble SME instructions (this is
supported in recent versions of clang). Note that the host being AArch64
is already assumed by the integration tests linking other runtime
libraries (e.g. mlir_c_runner_utils).
This removes the direct dependency on compiler-rt and instead includes
ABI stub routines in MLIR. Our current tests pass with only stubs, as
we're not making nested ZA-enabled calls. Using these stubs can be
overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache
variable to a path to an alternate implementation.
- ArmSMEStub.cpp -> ArmSMEStubs.cpp
- Move enums to tablegen (to get generated stringification)
- Make enums closer to ACLE
  * Remove "Default" mode
  * Similar naming:
  * ArmZaMode::New -> ArmZaMode::NewZA
  * ArmStreamingMode::Locally -> ArmStreamingMode::StreamingLocally
These required a few little changes to avoid nested ZA enabled calls.
@MacDue MacDue merged commit 783ac3b into llvm:main Nov 14, 2023
@MacDue MacDue deleted the use_sme_abi branch November 14, 2023 12:50
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023