Skip to content

Commit 783ac3b

Browse files
authored
[mlir][ArmSME] Make use of backend function attributes for enabling ZA storage (#71044)
Previously, we were inserting za.enable/disable intrinsics for functions with the "arm_za" attribute (at the MLIR level), rather than using the backend attributes. This was done to avoid a dependency on the SME ABI functions from compiler-rt (which have only recently been implemented). Doing things this way did have correctness issues; for example, calling a streaming-mode function from another streaming-mode function (both with ZA enabled) would lead to ZA being disabled after returning to the caller (where it should still be enabled). Fixing issues like this would require re-doing the ABI work already done in the backend within MLIR. Instead, this patch switches to using the "arm_new_za" (backend) attribute for enabling ZA for an MLIR function. For the integration tests, this requires some way of linking the SME ABI functions. This is done via the `%arm_sme_abi_shlib` lit substitution. By default, this expands to a stub implementation of the SME ABI functions, but this can be overridden by providing the `ARM_SME_ABI_ROUTINES_SHLIB` CMake cache variable (pointing it at an alternative implementation). For now, the ArmSME integration tests pass with just stubs, as we don't make use of nested ZA-enabled calls. A future patch may add an option to compiler-rt to build the SME builtins into a standalone shared library to allow easily building/testing with the actual implementation.
1 parent 6684541 commit 783ac3b

29 files changed

+228
-168
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
161161
def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
162162
def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
163163

164-
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
165-
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
166-
167164
#endif // ARMSME_INTRINSIC_OPS
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
set(LLVM_TARGET_DEFINITIONS Passes.td)
22
mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSME)
3+
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
4+
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
35
add_public_tablegen_target(MLIRArmSMETransformsIncGen)
46

57
add_mlir_doc(Passes ArmSMEPasses ./ -gen-pass-doc)

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
1111

1212
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13+
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.h.inc"
1314
#include "mlir/Pass/Pass.h"
1415

1516
namespace mlir {
@@ -20,19 +21,13 @@ namespace arm_sme {
2021
//===----------------------------------------------------------------------===//
2122
// The EnableArmStreaming pass.
2223
//===----------------------------------------------------------------------===//
23-
// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
24-
// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
25-
// In a locally streaming function PSTATE.SM is kept internal and the callee
26-
// manages it on entry/exit.
27-
enum class ArmStreaming { Default = 0, Locally = 1 };
28-
2924
#define GEN_PASS_DECL
3025
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
3126

3227
/// Pass to enable Armv9 Streaming SVE mode.
33-
std::unique_ptr<Pass>
34-
createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
35-
const bool enableZA = false);
28+
std::unique_ptr<Pass> createEnableArmStreamingPass(
29+
const ArmStreamingMode = ArmStreamingMode::Streaming,
30+
const ArmZaMode = ArmZaMode::Disabled);
3631

3732
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
3833
std::unique_ptr<Pass> createTileAllocationPass();

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,32 @@
1010
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
1111

1212
include "mlir/Pass/PassBase.td"
13+
include "mlir/IR/EnumAttr.td"
14+
15+
def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
16+
[
17+
I32EnumAttrCase<"Disabled", 0, "disabled">,
18+
// Streaming: Streaming-mode is part of the function interface (ABI).
19+
I32EnumAttrCase<"Streaming", 1, "arm_streaming">,
20+
// StreamingLocally: PSTATE.SM is kept internal and the callee manages it
21+
// on entry/exit.
22+
I32EnumAttrCase<"StreamingLocally", 2, "arm_locally_streaming">,
23+
]>{
24+
let cppNamespace = "mlir::arm_sme";
25+
let genSpecializedAttr = 0;
26+
}
27+
28+
// TODO: Add other ZA modes.
29+
// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
30+
def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
31+
[
32+
I32EnumAttrCase<"Disabled", 0, "disabled">,
33+
// A function's ZA state is created on entry and destroyed on exit.
34+
I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
35+
]>{
36+
let cppNamespace = "mlir::arm_sme";
37+
let genSpecializedAttr = 0;
38+
}
1339

1440
def EnableArmStreaming
1541
: Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
@@ -22,19 +48,32 @@ def EnableArmStreaming
2248
}];
2349
let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
2450
let options = [
25-
Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
26-
/*default=*/"mlir::arm_sme::ArmStreaming::Default",
51+
Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
52+
/*default=*/"mlir::arm_sme::ArmStreamingMode::Streaming",
2753
"Select how streaming-mode is managed at the function-level.",
2854
[{::llvm::cl::values(
29-
clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
30-
"Streaming mode is part of the function interface "
31-
"(ABI), caller manages PSTATE.SM on entry/exit."),
32-
clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
33-
"Streaming mode is internal to the function, callee "
34-
"manages PSTATE.SM on entry/exit.")
55+
clEnumValN(mlir::arm_sme::ArmStreamingMode::Disabled,
56+
"disabled", "Streaming mode is disabled."),
57+
clEnumValN(mlir::arm_sme::ArmStreamingMode::Streaming,
58+
"streaming",
59+
"Streaming mode is part of the function interface "
60+
"(ABI), caller manages PSTATE.SM on entry/exit."),
61+
clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingLocally,
62+
"streaming-locally",
63+
"Streaming mode is internal to the function, callee "
64+
"manages PSTATE.SM on entry/exit.")
3565
)}]>,
36-
Option<"enableZA", "enable-za", "bool", /*default=*/"false",
37-
"Enable ZA storage array.">,
66+
Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
67+
/*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
68+
"Select how ZA-storage is managed at the function-level.",
69+
[{::llvm::cl::values(
70+
clEnumValN(mlir::arm_sme::ArmZaMode::Disabled,
71+
"disabled", "ZA storage is disabled."),
72+
clEnumValN(mlir::arm_sme::ArmZaMode::NewZA,
73+
"new-za",
74+
"The function has ZA state. The ZA state is "
75+
"created on entry and destroyed on exit.")
76+
)}]>
3877
];
3978
let dependentDialects = ["func::FuncDialect"];
4079
}

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
13871387
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
13881388
OptionalAttr<UnitAttr>:$arm_streaming,
13891389
OptionalAttr<UnitAttr>:$arm_locally_streaming,
1390+
OptionalAttr<UnitAttr>:$arm_new_za,
13901391
OptionalAttr<StrAttr>:$section,
13911392
OptionalAttr<UnnamedAddr>:$unnamed_addr,
13921393
OptionalAttr<I64Attr>:$alignment,

mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
//===----------------------------------------------------------------------===//
3535

3636
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
37+
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
3738

3839
#include "mlir/Dialect/Func/IR/FuncOps.h"
3940

@@ -48,46 +49,38 @@ namespace arm_sme {
4849

4950
using namespace mlir;
5051
using namespace mlir::arm_sme;
52+
namespace {
5153

52-
static constexpr char kArmStreamingAttr[] = "arm_streaming";
53-
static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
54-
static constexpr char kArmZAAttr[] = "arm_za";
55-
static constexpr char kEnableArmStreamingIgnoreAttr[] =
56-
"enable_arm_streaming_ignore";
54+
constexpr StringLiteral
55+
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
5756

58-
namespace {
5957
struct EnableArmStreamingPass
6058
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
61-
EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
62-
this->mode = mode;
63-
this->enableZA = enableZA;
59+
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
60+
this->streamingMode = streamingMode;
61+
this->zaMode = zaMode;
6462
}
6563
void runOnOperation() override {
66-
if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
64+
auto op = getOperation();
65+
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
66+
streamingMode == ArmStreamingMode::Disabled)
6767
return;
68-
StringRef attr;
69-
switch (mode) {
70-
case ArmStreaming::Default:
71-
attr = kArmStreamingAttr;
72-
break;
73-
case ArmStreaming::Locally:
74-
attr = kArmLocallyStreamingAttr;
75-
break;
76-
}
77-
getOperation()->setAttr(attr, UnitAttr::get(&getContext()));
68+
69+
auto unitAttr = UnitAttr::get(&getContext());
70+
71+
op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
7872

7973
// The pass currently only supports enabling ZA when in streaming-mode, but
8074
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
8175
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
8276
// supporting this later.
83-
if (enableZA)
84-
getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
77+
if (zaMode != ArmZaMode::Disabled)
78+
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
8579
}
8680
};
8781
} // namespace
8882

89-
std::unique_ptr<Pass>
90-
mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
91-
const bool enableZA) {
92-
return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
83+
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
84+
const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
85+
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
9386
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,33 +21,6 @@ using namespace mlir;
2121
using namespace mlir::arm_sme;
2222

2323
namespace {
24-
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
25-
/// ops to enable the ZA storage array.
26-
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
27-
using OpRewritePattern::OpRewritePattern;
28-
LogicalResult matchAndRewrite(func::FuncOp op,
29-
PatternRewriter &rewriter) const final {
30-
OpBuilder::InsertionGuard g(rewriter);
31-
rewriter.setInsertionPointToStart(&op.front());
32-
rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
33-
rewriter.updateRootInPlace(op, [] {});
34-
return success();
35-
}
36-
};
37-
38-
/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
39-
/// disable the ZA storage array.
40-
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
41-
using OpRewritePattern::OpRewritePattern;
42-
LogicalResult matchAndRewrite(func::ReturnOp op,
43-
PatternRewriter &rewriter) const final {
44-
OpBuilder::InsertionGuard g(rewriter);
45-
rewriter.setInsertionPoint(op);
46-
rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
47-
rewriter.updateRootInPlace(op, [] {});
48-
return success();
49-
}
50-
};
5124

5225
/// Lower 'arm_sme.zero' to SME intrinsics.
5326
///
@@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
678651
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
679652
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
680653
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
681-
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
682-
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
654+
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
683655
target.addLegalOp<GetTileID>();
684656
target.addIllegalOp<vector::OuterProductOp>();
685-
686-
// Mark 'func.func' ops as legal if either:
687-
// 1. no 'arm_za' function attribute is present.
688-
// 2. the 'arm_za' function attribute is present and the first op in the
689-
// function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
690-
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
691-
if (funcOp.isDeclaration())
692-
return true;
693-
auto firstOp = funcOp.getBody().front().begin();
694-
return !funcOp->hasAttr("arm_za") ||
695-
isa<arm_sme::aarch64_sme_za_enable>(firstOp);
696-
});
697-
698-
// Mark 'func.return' ops as legal if either:
699-
// 1. no 'arm_za' function attribute is present.
700-
// 2. the 'arm_za' function attribute is present and there's a preceding
701-
// 'arm_sme::aarch64_sme_za_disable' intrinsic.
702-
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
703-
bool hasDisableZA = false;
704-
auto funcOp = returnOp->getParentOp();
705-
funcOp->walk<WalkOrder::PreOrder>(
706-
[&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
707-
return !funcOp->hasAttr("arm_za") || hasDisableZA;
708-
});
709657
}
710658

711659
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
712660
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
713-
patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
714661
patterns.add<
715662
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
716663
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- ArmSMEStubs.cpp - ArmSME ABI routine stubs -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Support/Compiler.h"
10+
#include <cstdint>
11+
#include <iostream>
12+
13+
// The actual implementation of these routines is in:
14+
// compiler-rt/lib/builtins/aarch64/sme-abi.S. These stubs allow the current
15+
// ArmSME tests to run without depending on compiler-rt. This works as we don't
16+
// rely on nested ZA-enabled calls at the moment. The use of these stubs can be
17+
// overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable to
18+
// a path to an alternate implementation.
19+
20+
extern "C" {
21+
22+
bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
23+
// The ArmSME tests are run within an emulator so we assume SME is available.
24+
return true;
25+
}
26+
27+
struct sme_state {
28+
int64_t x0;
29+
int64_t x1;
30+
};
31+
32+
sme_state LLVM_ATTRIBUTE_WEAK __arm_sme_state() {
33+
std::cerr << "[warning] __arm_sme_state() stubbed!\n";
34+
return sme_state{};
35+
}
36+
37+
void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_restore() {
38+
std::cerr << "[warning] __arm_tpidr2_restore() stubbed!\n";
39+
}
40+
41+
void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_save() {
42+
std::cerr << "[warning] __arm_tpidr2_save() stubbed!\n";
43+
}
44+
45+
void LLVM_ATTRIBUTE_WEAK __arm_za_disable() {
46+
std::cerr << "[warning] __arm_za_disable() stubbed!\n";
47+
}
48+
}

mlir/lib/ExecutionEngine/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# is a big dependency which most don't need.
33

44
set(LLVM_OPTIONAL_SOURCES
5+
ArmSMEStubs.cpp
56
AsyncRuntime.cpp
67
CRunnerUtils.cpp
78
CudaRuntimeWrappers.cpp
@@ -177,6 +178,10 @@ if(LLVM_ENABLE_PIC)
177178
target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
178179
endif()
179180

181+
add_mlir_library(mlir_arm_sme_abi_stubs
182+
SHARED
183+
ArmSMEStubs.cpp)
184+
180185
if(MLIR_ENABLE_CUDA_RUNNER)
181186
# Configure CUDA support. Using check_language first allows us to give a
182187
# custom error message.

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
15731573
funcOp.setMemoryAttr(memAttr);
15741574
}
15751575

1576+
// List of LLVM IR attributes that map to an explicit attribute on the MLIR
1577+
// LLVMFuncOp.
1578+
static constexpr std::array ExplicitAttributes{
1579+
StringLiteral("aarch64_pstate_sm_enabled"),
1580+
StringLiteral("aarch64_pstate_sm_body"),
1581+
StringLiteral("aarch64_pstate_za_new"),
1582+
StringLiteral("vscale_range"),
1583+
};
1584+
15761585
static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
15771586
MLIRContext *context = funcOp.getContext();
15781587
SmallVector<Attribute> passthroughs;
@@ -1598,11 +1607,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
15981607
attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
15991608
auto keyAttr = StringAttr::get(context, attrName);
16001609

1601-
// Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
1602-
// explicit attribute.
1603-
// Also skip the vscale_range, it is also an explicit attribute.
1604-
if (attrName == "aarch64_pstate_sm_enabled" ||
1605-
attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
1610+
// Skip attributes that map to an explicit attribute on the LLVMFuncOp.
1611+
if (llvm::is_contained(ExplicitAttributes, attrName))
16061612
continue;
16071613

16081614
if (attr.isStringAttribute()) {
@@ -1642,6 +1648,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
16421648
funcOp.setArmStreaming(true);
16431649
else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
16441650
funcOp.setArmLocallyStreaming(true);
1651+
1652+
if (func->hasFnAttribute("aarch64_pstate_za_new"))
1653+
funcOp.setArmNewZa(true);
1654+
16451655
llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
16461656
if (attr.isValid()) {
16471657
MLIRContext *context = funcOp.getContext();

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
890890
else if (func.getArmLocallyStreaming())
891891
llvmFunc->addFnAttr("aarch64_pstate_sm_body");
892892

893+
if (func.getArmNewZa())
894+
llvmFunc->addFnAttr("aarch64_pstate_za_new");
895+
893896
if (auto attr = func.getVscaleRange())
894897
llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
895898
getLLVMContext(), attr->getMinRange().getInt(),

0 commit comments

Comments
 (0)