Skip to content

[mlir][ArmSME] Add option to only enable streaming mode/ZA if required #73931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled);
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);

/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def EnableArmStreaming
"new-za",
"The function has ZA state. The ZA state is "
"created on entry and destroyed on exit.")
)}]>
)}]>,
Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function "
" contains ops that require them.">
];
let dependentDialects = ["func::FuncDialect"];
}
Expand Down
25 changes: 22 additions & 3 deletions mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"

Expand All @@ -56,12 +57,28 @@ constexpr StringLiteral

struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->onlyIfRequiredByOps = onlyIfRequiredByOps;
}
void runOnOperation() override {
auto op = getOperation();

if (onlyIfRequiredByOps) {
bool foundTileOp = false;
op.walk([&](Operation *op) {
if (llvm::isa<ArmSMETileOpInterface>(op)) {
foundTileOp = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (!foundTileOp)
return;
}

if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;
Expand All @@ -81,6 +98,8 @@ struct EnableArmStreamingPass
} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
bool onlyIfRequiredByOps) {
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
onlyIfRequiredByOps);
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED

// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
Expand All @@ -17,3 +18,18 @@ func.func @arm_streaming() { return }
// CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }

// CHECK-LABEL: @requires_arm_streaming
// CHECK-SAME: attributes {arm_streaming}
// IF-REQUIRED: @requires_arm_streaming
// IF-REQUIRED-SAME: attributes {arm_streaming}
func.func @requires_arm_streaming() {
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}

// CHECK-LABEL: @does_not_require_arm_streaming
// CHECK-SAME: attributes {arm_streaming}
// IF-REQUIRED: @does_not_require_arm_streaming
// IF-REQUIRED-NOT: arm_streaming
func.func @does_not_require_arm_streaming() { return }