Skip to content

[mlir][ArmSME] Make use of backend function attributes for enabling ZA storage #71044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;

def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;

#endif // ARMSME_INTRINSIC_OPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSME)
mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRArmSMETransformsIncGen)

add_mlir_doc(Passes ArmSMEPasses ./ -gen-pass-doc)
13 changes: 4 additions & 9 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.h.inc"
#include "mlir/Pass/Pass.h"

namespace mlir {
Expand All @@ -20,19 +21,13 @@ namespace arm_sme {
//===----------------------------------------------------------------------===//
// The EnableArmStreaming pass.
//===----------------------------------------------------------------------===//
// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
// In a locally streaming function PSTATE.SM is kept internal and the callee
// manages it on entry/exit.
enum class ArmStreaming { Default = 0, Locally = 1 };

#define GEN_PASS_DECL
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"

/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass>
createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
const bool enableZA = false);
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled);

/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
std::unique_ptr<Pass> createTileAllocationPass();
Expand Down
59 changes: 49 additions & 10 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,32 @@
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"
include "mlir/IR/EnumAttr.td"

// Streaming-mode selection for the 'enable-arm-streaming' pass (used as the
// type of its 'streaming-mode' option). For the non-disabled cases the string
// form of the case is the name of the unit attribute the pass attaches to the
// function.
def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
[
// Disabled: the pass returns without attaching any attribute.
I32EnumAttrCase<"Disabled", 0, "disabled">,
// Streaming: Streaming-mode is part of the function interface (ABI).
I32EnumAttrCase<"Streaming", 1, "arm_streaming">,
// StreamingLocally: PSTATE.SM is kept internal and the callee manages it
// on entry/exit.
I32EnumAttrCase<"StreamingLocally", 2, "arm_locally_streaming">,
]>{
let cppNamespace = "mlir::arm_sme";
// NOTE(review): presumably only the C++ enum and its helpers are wanted here,
// not a dedicated attribute class -- confirm against the EnumAttr docs.
let genSpecializedAttr = 0;
}

// ZA storage selection for the 'enable-arm-streaming' pass (used as the type
// of its 'za-mode' option). For the non-disabled cases the string form of the
// case is the name of the unit attribute the pass attaches to the function.
//
// TODO: Add other ZA modes.
// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
[
// Disabled: no ZA-related attribute is attached to the function.
I32EnumAttrCase<"Disabled", 0, "disabled">,
// A function's ZA state is created on entry and destroyed on exit.
I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
]>{
let cppNamespace = "mlir::arm_sme";
// NOTE(review): presumably only the C++ enum and its helpers are wanted here,
// not a dedicated attribute class -- confirm against the EnumAttr docs.
let genSpecializedAttr = 0;
}

def EnableArmStreaming
: Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
Expand All @@ -22,19 +48,32 @@ def EnableArmStreaming
}];
let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
let options = [
Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
/*default=*/"mlir::arm_sme::ArmStreaming::Default",
Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
/*default=*/"mlir::arm_sme::ArmStreamingMode::Streaming",
"Select how streaming-mode is managed at the function-level.",
[{::llvm::cl::values(
clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
"Streaming mode is part of the function interface "
"(ABI), caller manages PSTATE.SM on entry/exit."),
clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
"Streaming mode is internal to the function, callee "
"manages PSTATE.SM on entry/exit.")
clEnumValN(mlir::arm_sme::ArmStreamingMode::Disabled,
"disabled", "Streaming mode is disabled."),
clEnumValN(mlir::arm_sme::ArmStreamingMode::Streaming,
"streaming",
"Streaming mode is part of the function interface "
"(ABI), caller manages PSTATE.SM on entry/exit."),
clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingLocally,
"streaming-locally",
"Streaming mode is internal to the function, callee "
"manages PSTATE.SM on entry/exit.")
)}]>,
Option<"enableZA", "enable-za", "bool", /*default=*/"false",
"Enable ZA storage array.">,
Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
/*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
"Select how ZA-storage is managed at the function-level.",
[{::llvm::cl::values(
clEnumValN(mlir::arm_sme::ArmZaMode::Disabled,
"disabled", "ZA storage is disabled."),
clEnumValN(mlir::arm_sme::ArmZaMode::NewZA,
"new-za",
"The function has ZA state. The ZA state is "
"created on entry and destroyed on exit.")
)}]>
];
let dependentDialects = ["func::FuncDialect"];
}
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
OptionalAttr<UnitAttr>:$arm_streaming,
OptionalAttr<UnitAttr>:$arm_locally_streaming,
OptionalAttr<UnitAttr>:$arm_new_za,
OptionalAttr<StrAttr>:$section,
OptionalAttr<UnnamedAddr>:$unnamed_addr,
OptionalAttr<I64Attr>:$alignment,
Expand Down
45 changes: 19 additions & 26 deletions mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"

#include "mlir/Dialect/Func/IR/FuncOps.h"

Expand All @@ -48,46 +49,38 @@ namespace arm_sme {

using namespace mlir;
using namespace mlir::arm_sme;
namespace {

static constexpr char kArmStreamingAttr[] = "arm_streaming";
static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
static constexpr char kArmZAAttr[] = "arm_za";
static constexpr char kEnableArmStreamingIgnoreAttr[] =
"enable_arm_streaming_ignore";
constexpr StringLiteral
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");

namespace {
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
this->mode = mode;
this->enableZA = enableZA;
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
}
void runOnOperation() override {
if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
auto op = getOperation();
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;
StringRef attr;
switch (mode) {
case ArmStreaming::Default:
attr = kArmStreamingAttr;
break;
case ArmStreaming::Locally:
attr = kArmLocallyStreamingAttr;
break;
}
getOperation()->setAttr(attr, UnitAttr::get(&getContext()));

auto unitAttr = UnitAttr::get(&getContext());

op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);

// The pass currently only supports enabling ZA when in streaming-mode, but
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
// supporting this later.
if (enableZA)
getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
if (zaMode != ArmZaMode::Disabled)
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
}
};
} // namespace

std::unique_ptr<Pass>
mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
const bool enableZA) {
return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
}
55 changes: 1 addition & 54 deletions mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,6 @@ using namespace mlir;
using namespace mlir::arm_sme;

namespace {
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
/// ops to enable the ZA storage array.
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(&op.front());
rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
rewriter.updateRootInPlace(op, [] {});
return success();
}
};

/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
/// disable the ZA storage array.
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(func::ReturnOp op,
PatternRewriter &rewriter) const final {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
rewriter.updateRootInPlace(op, [] {});
return success();
}
};

/// Lower 'arm_sme.zero' to SME intrinsics.
///
Expand Down Expand Up @@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
target.addLegalOp<GetTileID>();
target.addIllegalOp<vector::OuterProductOp>();

// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
// 2. the 'arm_za' function attribute is present and the first op in the
// function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
if (funcOp.isDeclaration())
return true;
auto firstOp = funcOp.getBody().front().begin();
return !funcOp->hasAttr("arm_za") ||
isa<arm_sme::aarch64_sme_za_enable>(firstOp);
});

// Mark 'func.return' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
// 2. the 'arm_za' function attribute is present and there's a preceding
// 'arm_sme::aarch64_sme_za_disable' intrinsic.
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
bool hasDisableZA = false;
auto funcOp = returnOp->getParentOp();
funcOp->walk<WalkOrder::PreOrder>(
[&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
return !funcOp->hasAttr("arm_za") || hasDisableZA;
});
}

void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
Expand Down
48 changes: 48 additions & 0 deletions mlir/lib/ExecutionEngine/ArmSMEStubs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===- ArmSMEStubs.cpp - ArmSME ABI routine stubs -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/Compiler.h"
#include <cstdint>
#include <iostream>

// The actual implementation of these routines is in:
// compiler-rt/lib/builtins/aarch64/sme-abi.S. These stubs allow the current
// ArmSME tests to run without depending on compiler-rt. This works as we don't
// rely on nested ZA-enabled calls at the moment. The use of these stubs can be
// overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable to
// a path to an alternate implementation.

extern "C" {

/// Stub of the `__aarch64_sme_accessible` support routine: unconditionally
/// reports that SME can be used.
bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
  // SME is assumed to be present because the ArmSME tests execute inside an
  // emulator.
  return true;
}

/// Two 64-bit values handed back by `__arm_sme_state`.
/// NOTE(review): presumably these mirror the x0/x1 return registers of the
/// real routine -- confirm against the SME ABI.
struct sme_state {
  int64_t x0;
  int64_t x1;
};

/// Stub of `__arm_sme_state`: emits a warning on stderr and yields a
/// zero-initialised state.
sme_state LLVM_ATTRIBUTE_WEAK __arm_sme_state() {
  std::cerr << "[warning] __arm_sme_state() stubbed!\n";
  return {};
}

/// Stub of `__arm_tpidr2_restore`: only emits a warning on stderr.
void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_restore() {
  std::cerr << "[warning] __arm_tpidr2_restore() stubbed!\n";
}

/// Stub of `__arm_tpidr2_save`: only emits a warning on stderr.
void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_save() {
  std::cerr << "[warning] __arm_tpidr2_save() stubbed!\n";
}

/// Stub of `__arm_za_disable`: only emits a warning on stderr.
void LLVM_ATTRIBUTE_WEAK __arm_za_disable() {
  std::cerr << "[warning] __arm_za_disable() stubbed!\n";
}

} // extern "C"
5 changes: 5 additions & 0 deletions mlir/lib/ExecutionEngine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# is a big dependency which most don't need.

set(LLVM_OPTIONAL_SOURCES
ArmSMEStubs.cpp
AsyncRuntime.cpp
CRunnerUtils.cpp
CudaRuntimeWrappers.cpp
Expand Down Expand Up @@ -177,6 +178,10 @@ if(LLVM_ENABLE_PIC)
target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
endif()

add_mlir_library(mlir_arm_sme_abi_stubs
SHARED
ArmSMEStubs.cpp)

if(MLIR_ENABLE_CUDA_RUNNER)
# Configure CUDA support. Using check_language first allows us to give a
# custom error message.
Expand Down
20 changes: 15 additions & 5 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
funcOp.setMemoryAttr(memAttr);
}

// List of LLVM IR attributes that map to an explicit attribute on the MLIR
// LLVMFuncOp. The passthrough-attribute import skips any attribute whose name
// appears here, so it is not also recorded as a generic passthrough entry.
static constexpr std::array ExplicitAttributes{
// NOTE(review): presumably imported as `arm_streaming` -- the guarding
// condition is not fully visible in this view; confirm.
StringLiteral("aarch64_pstate_sm_enabled"),
// Imported by setting `arm_locally_streaming` on the LLVMFuncOp.
StringLiteral("aarch64_pstate_sm_body"),
// Imported by setting `arm_new_za` on the LLVMFuncOp.
StringLiteral("aarch64_pstate_za_new"),
// NOTE(review): presumably imported as the function's vscale range
// attribute -- confirm against processFunctionAttributes.
StringLiteral("vscale_range"),
};

static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
SmallVector<Attribute> passthroughs;
Expand All @@ -1598,11 +1607,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
auto keyAttr = StringAttr::get(context, attrName);

// Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
// explicit attribute.
// Also skip the vscale_range, it is also an explicit attribute.
if (attrName == "aarch64_pstate_sm_enabled" ||
attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
// Skip attributes that map to an explicit attribute on the LLVMFuncOp.
if (llvm::is_contained(ExplicitAttributes, attrName))
continue;

if (attr.isStringAttribute()) {
Expand Down Expand Up @@ -1642,6 +1648,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
funcOp.setArmStreaming(true);
else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
funcOp.setArmLocallyStreaming(true);

if (func->hasFnAttribute("aarch64_pstate_za_new"))
funcOp.setArmNewZa(true);

llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
if (attr.isValid()) {
MLIRContext *context = funcOp.getContext();
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
else if (func.getArmLocallyStreaming())
llvmFunc->addFnAttr("aarch64_pstate_sm_body");

if (func.getArmNewZa())
llvmFunc->addFnAttr("aarch64_pstate_za_new");

if (auto attr = func.getVscaleRange())
llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
getLLVMContext(), attr->getMinRange().getInt(),
Expand Down
Loading