Skip to content

Commit 7a2fdc6

Browse files
gaofangfrankpsoni2628
authored andcommitted
[mlir][ArmSME] Dialect and Intrinsic Op Definition
This patch creates the ArmSME dialect, and provides the intrinsic op definition necessary for lowering to LLVM IR. This will cover most instructions interacting with the ZA tile register, not covering SME2 instructions. Source: https://developer.arm.com/documentation/ddi0616/latest Reviewed By: awarzynski, c-rhodes Differential Revision: https://reviews.llvm.org/D152878
1 parent 3d5cf0d commit 7a2fdc6

File tree

16 files changed

+542
-0
lines changed

16 files changed

+542
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
add_subdirectory(IR)
12
add_subdirectory(Transforms)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- ArmSMEDialect.h - MLIR Dialect for Arm SME ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the Target dialect for ArmSME in MLIR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_ARMSME_IR_ARMSME_H
14+
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
15+
16+
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/IR/Dialect.h"
19+
#include "mlir/IR/OpDefinition.h"
20+
#include "mlir/Interfaces/SideEffectInterfaces.h"
21+
22+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
23+
24+
#define GET_OP_CLASSES
25+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
26+
27+
#endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines the ArmSME dialect and contains intrinsic ops to lower to
10+
// LLVM IR.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef ARMSME_OPS
15+
#define ARMSME_OPS
16+
17+
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+
20+
//===----------------------------------------------------------------------===//
21+
// ArmSME dialect definition
22+
//===----------------------------------------------------------------------===//
23+
24+
def ArmSME_Dialect : Dialect {
25+
let name = "arm_sme";
26+
let cppNamespace = "::mlir::arm_sme";
27+
let summary = "Basic dialect to target Arm SME architectures";
28+
let description = [{
29+
This dialect contains the definitions necessary to target Arm SME
30+
scalable matrix operations.
31+
32+
Sources:
33+
https://developer.arm.com/documentation/ddi0616
34+
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
35+
}];
36+
}
37+
38+
//===----------------------------------------------------------------------===//
39+
// ArmSME Intrinsic op definitions
40+
//===----------------------------------------------------------------------===//
41+
42+
def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
43+
def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
44+
[I8, I16, BF16, F16, F32, F64]>;
45+
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
46+
47+
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
48+
list<Trait> traits = []>
49+
: LLVM_IntrOpBase<
50+
/*Dialect dialect=*/ArmSME_Dialect,
51+
/*string opName=*/"intr." # mnemonic,
52+
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
53+
/*list<int> overloadedResults=*/[],
54+
/*list<int> overloadedOperands=*/overloadedOperands,
55+
/*list<Trait> traits=*/traits,
56+
/*int numResults=*/0>;
57+
58+
// Zero
59+
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
60+
Arguments<(ins Arg<I32, "Tile mask">)>;
61+
62+
// MOP's
63+
class ArmSME_IntrMopOverloadedOp<string mnemonic>
64+
: ArmSME_IntrOp<mnemonic, [4]>,
65+
Arguments<(ins Arg<I32, "Virtual tile ID">,
66+
Arg<MOPPredicate, "LHS predicate">,
67+
Arg<MOPPredicate, "RHS predicate">,
68+
Arg<MOPVector, "LHS vector operand">,
69+
Arg<MOPVector, "RHS vector operand">)>;
70+
71+
def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
72+
def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
73+
def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
74+
def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
75+
def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
76+
def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
77+
def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
78+
def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
79+
def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
80+
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
81+
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
82+
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
83+
84+
// Loads
85+
class ArmSME_IntrLoadOp<string mnemonic>
86+
: ArmSME_IntrOp<mnemonic>,
87+
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
88+
Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
89+
Arg<I32, "Virtual tile ID">,
90+
Arg<I32, "Tile slice">)>;
91+
92+
def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
93+
def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
94+
def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
95+
def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
96+
def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
97+
def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
98+
def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
99+
def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
100+
def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
101+
def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
102+
103+
// Stores
104+
class ArmSME_IntrStoreOp<string mnemonic>
105+
: ArmSME_IntrOp<mnemonic>,
106+
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
107+
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
108+
Arg<I32, "Virtual tile ID">,
109+
Arg<I32, "Tile slice">)>;
110+
111+
def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
112+
def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
113+
def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
114+
def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
115+
def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
116+
def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
117+
def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
118+
def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
119+
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
120+
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
121+
122+
#endif // ARMSME_OPS
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_mlir_dialect(ArmSME arm_sme ArmSME)
2+
add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
3+
4+
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
5+
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
6+
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
2424
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
2525
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
26+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
2627
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
2728
#include "mlir/Dialect/Async/IR/Async.h"
2829
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
117118
pdl_interp::PDLInterpDialect,
118119
quant::QuantizationDialect,
119120
spirv::SPIRVDialect,
121+
arm_sme::ArmSMEDialect,
120122
arm_sve::ArmSVEDialect,
121123
vector::VectorDialect,
122124
NVVM::NVVMDialect,

mlir/include/mlir/Target/LLVMIR/Dialect/All.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
1818
#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
19+
#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
1920
#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
2021
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
2122
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -35,6 +36,7 @@ class DialectRegistry;
3536
static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
3637
registerArmNeonDialectTranslation(registry);
3738
registerAMXDialectTranslation(registry);
39+
registerArmSMEDialectTranslation(registry);
3840
registerArmSVEDialectTranslation(registry);
3941
registerBuiltinDialectTranslation(registry);
4042
registerGPUDialectTranslation(registry);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This provides registration calls for ArmSME dialect to LLVM IR translation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
14+
#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
15+
16+
namespace mlir {
17+
18+
class DialectRegistry;
19+
class MLIRContext;
20+
21+
/// Register the ArmSME dialect and the translation from it to the LLVM IR in
22+
/// the given registry;
23+
void registerArmSMEDialectTranslation(DialectRegistry &registry);
24+
25+
/// Register the ArmSME dialect and the translation from it in the registry
26+
/// associated with the given context.
27+
void registerArmSMEDialectTranslation(MLIRContext &context);
28+
29+
} // namespace mlir
30+
31+
#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
add_subdirectory(IR)
12
add_subdirectory(Transforms)

mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the ArmSME dialect and its operations.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::arm_sme;
18+
19+
//===----------------------------------------------------------------------===//
20+
// Tablegen Definitions
21+
//===----------------------------------------------------------------------===//
22+
23+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
24+
25+
#define GET_OP_CLASSES
26+
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
27+
28+
#define GET_TYPEDEF_CLASSES
29+
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
30+
31+
void ArmSMEDialect::initialize() {
32+
addOperations<
33+
#define GET_OP_LIST
34+
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
35+
>();
36+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
add_mlir_dialect_library(MLIRArmSMEDialect
2+
ArmSME.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
6+
7+
DEPENDS
8+
MLIRArmSMEIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRLLVMDialect
13+
MLIRSideEffectInterfaces
14+
)

mlir/lib/Target/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
4646

4747
LINK_LIBS PUBLIC
4848
MLIRArmNeonToLLVMIRTranslation
49+
MLIRArmSMEToLLVMIRTranslation
4950
MLIRArmSVEToLLVMIRTranslation
5051
MLIRAMXToLLVMIRTranslation
5152
MLIRBuiltinToLLVMIRTranslation
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a translation between the ArmSME dialect and LLVM IR.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
14+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
15+
#include "mlir/IR/Operation.h"
16+
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
17+
18+
#include "llvm/IR/IRBuilder.h"
19+
#include "llvm/IR/IntrinsicsAArch64.h"
20+
21+
using namespace mlir;
22+
using namespace mlir::LLVM;
23+
24+
namespace {
25+
/// Implementation of the dialect interface that converts operations belonging
26+
/// to the ArmSME dialect to LLVM IR.
27+
class ArmSMEDialectLLVMIRTranslationInterface
28+
: public LLVMTranslationDialectInterface {
29+
public:
30+
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
31+
32+
/// Translates the given operation to LLVM IR using the provided IR builder
33+
/// and saving the state in `moduleTranslation`.
34+
LogicalResult
35+
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
36+
LLVM::ModuleTranslation &moduleTranslation) const final {
37+
Operation &opInst = *op;
38+
#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
39+
40+
return failure();
41+
}
42+
};
43+
} // namespace
44+
45+
void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) {
46+
registry.insert<arm_sme::ArmSMEDialect>();
47+
registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) {
48+
dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>();
49+
});
50+
}
51+
52+
void mlir::registerArmSMEDialectTranslation(MLIRContext &context) {
53+
DialectRegistry registry;
54+
registerArmSMEDialectTranslation(registry);
55+
context.appendDialectRegistry(registry);
56+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation
2+
ArmSMEToLLVMIRTranslation.cpp
3+
4+
DEPENDS
5+
MLIRArmSMEConversionsIncGen
6+
7+
LINK_COMPONENTS
8+
Core
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRArmSMEDialect
13+
MLIRLLVMDialect
14+
MLIRSupport
15+
MLIRTargetLLVMIRExport
16+
)

mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(ArmNeon)
2+
add_subdirectory(ArmSME)
23
add_subdirectory(ArmSVE)
34
add_subdirectory(AMX)
45
add_subdirectory(Builtin)

0 commit comments

Comments
 (0)