Skip to content

Commit 65305ae

Browse files
committed
[mlir][ArmSME] Insert intrinsics to enable/disable ZA
This patch adds two LLVM intrinsics to the ArmSME dialect: * llvm.aarch64.sme.za.enable * llvm.aarch64.sme.za.disable for enabling/disabling the ZA storage array [1], as well as patterns for inserting them during legalization to LLVM at the start and end of functions if the function has the 'arm_za' attribute (D152695). In the future ZA should probably be automatically enabled/disabled when lowering from vector to SME, but this should be sufficient for now, at least until we have patterns lowering to SME instructions that use ZA. N.B. The backend function attribute 'aarch64_pstate_za_new' can be used to manage ZA state (as was originally tried in D152694), but it emits calls to the following SME support routines [2] for the lazy-save mechanism [3]: * __arm_tpidr2_restore * __arm_tpidr2_save These will soon be added to compiler-rt but there's currently no public implementation, and using this attribute would introduce an MLIR dependency on compiler-rt. Furthermore, this mechanism is for routines with ZA enabled calling other routines with it also enabled. We can choose not to enable ZA in the compiler when this is the case. Depends on D152695 [1] https://developer.arm.com/documentation/ddi0616/aa [2] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#the-za-lazy-saving-scheme Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D153050
1 parent f8e67c4 commit 65305ae

File tree

9 files changed

+153
-0
lines changed

9 files changed

+153
-0
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm", "ModuleOp"> {
10921092
"bool", /*default=*/"false",
10931093
"Enables the use of ArmSVE dialect while lowering the vector "
10941094
"dialect.">,
1095+
Option<"armSME", "enable-arm-sme",
1096+
"bool", /*default=*/"false",
1097+
"Enables the use of ArmSME dialect while lowering the vector "
1098+
"dialect.">,
10951099
Option<"x86Vector", "enable-x86vector",
10961100
"bool", /*default=*/"false",
10971101
"Enables the use of X86Vector dialect while lowering the vector "

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
119119
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
120120
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
121121

122+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
123+
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
124+
122125
#endif // ARMSME_OPS
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
10+
#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
11+
12+
namespace mlir {
13+
14+
class LLVMConversionTarget;
15+
class LLVMTypeConverter;
16+
class RewritePatternSet;
17+
18+
/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
19+
/// intrinsics.
///
/// The patterns are appended to `patterns` and constructed with its context.
/// NOTE(review): the definition in LegalizeForLLVMExport.cpp does not appear
/// to use `converter` — confirm whether it is reserved for future patterns.
20+
void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
21+
RewritePatternSet &patterns);
22+
23+
/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
24+
/// intrinsics.
///
/// The definition in LegalizeForLLVMExport.cpp marks the ZA enable/disable
/// intrinsic ops as legal and registers dynamic legality callbacks for
/// 'func.func' and 'func.return' keyed on the 'arm_za' function attribute.
25+
void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
26+
27+
} // namespace mlir
28+
29+
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H

mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
1515
LINK_LIBS PUBLIC
1616
MLIRArithDialect
1717
MLIRArmNeonDialect
18+
MLIRArmSMEDialect
19+
MLIRArmSMETransforms
1820
MLIRArmSVEDialect
1921
MLIRArmSVETransforms
2022
MLIRAMXDialect

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "mlir/Dialect/AMX/Transforms.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
18+
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
1719
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
1820
#include "mlir/Dialect/ArmSVE/Transforms.h"
1921
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -49,6 +51,8 @@ struct LowerVectorToLLVMPass
4951
registry.insert<arm_neon::ArmNeonDialect>();
5052
if (armSVE)
5153
registry.insert<arm_sve::ArmSVEDialect>();
54+
if (armSME)
55+
registry.insert<arm_sme::ArmSMEDialect>();
5256
if (amx)
5357
registry.insert<amx::AMXDialect>();
5458
if (x86Vector)
@@ -102,6 +106,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
102106
configureArmSVELegalizeForExportTarget(target);
103107
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
104108
}
109+
if (armSME) {
110+
configureArmSMELegalizeForExportTarget(target);
111+
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
112+
}
105113
if (amx) {
106114
configureAMXLegalizeForExportTarget(target);
107115
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRArmSMETransforms
22
EnableArmStreaming.cpp
3+
LegalizeForLLVMExport.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +9,8 @@ add_mlir_dialect_library(MLIRArmSMETransforms
89
MLIRArmSMETransformsIncGen
910

1011
LINK_LIBS PUBLIC
12+
MLIRArmSMEDialect
1113
MLIRFuncDialect
14+
MLIRLLVMCommonConversion
1215
MLIRPass
1316
)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
11+
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::arm_sme;
16+
17+
namespace {
18+
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
19+
/// ops to enable the ZA storage array.
20+
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
21+
using OpRewritePattern::OpRewritePattern;
22+
LogicalResult matchAndRewrite(func::FuncOp op,
23+
PatternRewriter &rewriter) const final {
24+
OpBuilder::InsertionGuard g(rewriter);
25+
rewriter.setInsertionPointToStart(&op.front());
26+
rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
27+
rewriter.updateRootInPlace(op, [] {});
28+
return success();
29+
}
30+
};
31+
32+
/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
33+
/// disable the ZA storage array.
34+
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
35+
using OpRewritePattern::OpRewritePattern;
36+
LogicalResult matchAndRewrite(func::ReturnOp op,
37+
PatternRewriter &rewriter) const final {
38+
OpBuilder::InsertionGuard g(rewriter);
39+
rewriter.setInsertionPoint(op);
40+
rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
41+
rewriter.updateRootInPlace(op, [] {});
42+
return success();
43+
}
44+
};
45+
} // namespace
46+
47+
/// Appends the ZA enable/disable insertion patterns to `patterns`, building
/// them with the pattern set's own context. `converter` is not consulted here.
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<EnableZAPattern>(context);
  patterns.add<DisableZAPattern>(context);
}
51+
52+
/// Configures `target` for legalizing ArmSME to LLVM:
///   * the ZA enable/disable intrinsic ops are unconditionally legal;
///   * 'func.func' and 'func.return' ops are dynamically legal depending on
///     the 'arm_za' function attribute, so that the ZA insertion patterns only
///     fire where an enable/disable intrinsic is still missing.
void mlir::configureArmSMELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addLegalOp<arm_sme::aarch64_sme_za_enable,
                    arm_sme::aarch64_sme_za_disable>();

  // Mark 'func.func' ops as legal if any of the following holds:
  //   1. no 'arm_za' function attribute is present.
  //   2. the function is a declaration, i.e. it has no body into which an
  //      intrinsic could be inserted.
  //   3. the 'arm_za' function attribute is present and the first op in the
  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
  // The attribute and declaration checks run first so the entry block is only
  // dereferenced when it actually exists.
  target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp funcOp) {
    if (!funcOp->hasAttr("arm_za") || funcOp.isExternal())
      return true;
    Block &entryBlock = funcOp.front();
    return !entryBlock.empty() &&
           isa<arm_sme::aarch64_sme_za_enable>(entryBlock.front());
  });

  // Mark 'func.return' ops as legal if either:
  //   1. no 'arm_za' function attribute is present on the parent function.
  //   2. the 'arm_za' function attribute is present and the op immediately
  //      preceding the return is an 'arm_sme::aarch64_sme_za_disable'
  //      intrinsic.
  // Each return is checked on its own (rather than searching the whole
  // function for any disable op) so that every return of a function with
  // several returns gets its own disable intrinsic.
  target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp returnOp) {
    if (!returnOp->getParentOp()->hasAttr("arm_za"))
      return true;
    Operation *precedingOp = returnOp->getPrevNode();
    return precedingOp &&
           isa<arm_sme::aarch64_sme_za_disable>(precedingOp);
  });
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,DISABLE-ZA
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefixes=CHECK,NO-ARM-STREAMING

// The common CHECK prefix is enabled on every RUN line so that the label below
// is actually verified; the ZA intrinsics are only expected when the
// 'enable-za' option of the streaming pass is used.

// CHECK-LABEL: @arm_za
func.func @arm_za() {
  // ENABLE-ZA: arm_sme.intr.za.enable
  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
  // ENABLE-ZA-NEXT: return
  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
  return
}

mlir/test/Target/LLVMIR/arm-sme.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,14 @@ llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>,
223223
(vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
224224
llvm.return
225225
}
226+
227+
// -----
228+
229+
// Verifies that the ZA enable/disable intrinsic ops are translated to calls of
// the corresponding LLVM intrinsics.
// CHECK-LABEL: @arm_sme_toggle_za
230+
llvm.func @arm_sme_toggle_za() {
231+
// CHECK: call void @llvm.aarch64.sme.za.enable()
232+
"arm_sme.intr.za.enable"() : () -> ()
233+
// CHECK: call void @llvm.aarch64.sme.za.disable()
234+
"arm_sme.intr.za.disable"() : () -> ()
235+
llvm.return
236+
}

0 commit comments

Comments
 (0)