-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Add option to only enable streaming mode/ZA if required #73931
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This adds a `only-if-required-by-ops` flag to the `enable-arm-streaming` pass. This flag defaults to `false` (which preserves the original behaviour), however, if set to `true` the pass will only add the selected ZA/streaming mode to functions that contain ops that implement `ArmSMETileOpInterface`. This simplifies enabling these modes, as we can now first try lowering ops to ArmSME, then only if we succeed, add the relevant function attributes.
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis adds a This simplifies enabling these modes, as we can now first try lowering ops to ArmSME, then only if we succeed, add the relevant function attributes. Full diff: https://github.com/llvm/llvm-project/pull/73931.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 11a7385fe311dd3..21a97e9cbc794c9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,7 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
- const ArmZaMode = ArmZaMode::Disabled);
+ const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3253b47e62abddb..7b9c74e0b8f60e7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -73,7 +73,11 @@ def EnableArmStreaming
"new-za",
"The function has ZA state. The ZA state is "
"created on entry and destroyed on exit.")
- )}]>
+ )}]>,
+ Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+ /*default=*/"false",
+ "Only apply the selected streaming/ZA modes if the function "
+ " contains ops that require them.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index c3a1a1c9a3fb49e..79a6caffb6ee0bf 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -33,6 +33,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
@@ -56,12 +57,28 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
- EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+ EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
+ bool onlyIfRequiredByOps) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
+ this->onlyIfRequiredByOps = onlyIfRequiredByOps;
}
void runOnOperation() override {
auto op = getOperation();
+
+ if (onlyIfRequiredByOps) {
+ bool foundTileOp = false;
+ op.walk([&](Operation *op) {
+ if (llvm::isa<ArmSMETileOpInterface>(op)) {
+ foundTileOp = true;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ if (!foundTileOp)
+ return;
+ }
+
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;
@@ -81,6 +98,8 @@ struct EnableArmStreamingPass
} // namespace
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
- const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
- return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
+ const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
+ bool onlyIfRequiredByOps) {
+ return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
+ onlyIfRequiredByOps);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 70119b08c3e91aa..b1188acbc0b2d74 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -17,3 +18,18 @@ func.func @arm_streaming() { return }
// CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
+
+// CHECK-LABEL: @requires_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// IF-REQUIRED: @requires_arm_streaming
+// IF-REQUIRED-SAME: attributes {arm_streaming}
+func.func @requires_arm_streaming() {
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+ return
+}
+
+// CHECK-LABEL: @does_not_require_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// IF-REQUIRED: @does_not_require_arm_streaming
+// IF-REQUIRED-NOT: arm_streaming
+func.func @does_not_require_arm_streaming() { return }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the future we can look at making this the default behaviour, but for now this a good first step in that direction. LGTM, cheers!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM :)
This adds a
only-if-required-by-ops
flag to theenable-arm-streaming
pass. This flag defaults tofalse
(which preserves the original behaviour), however, if set totrue
the pass will only add the selected ZA/streaming mode to functions that contain ops that implementArmSMETileOpInterface
.This simplifies enabling these modes, as we can now first try lowering ops to ArmSME, then only if we succeed, add the relevant function attributes.