Skip to content

Commit 2e4161d

Browse files
authored
[mlir][ArmSME] Name arguments of SME intrinsics (NFC) (llvm#69608)
This makes the docs a little nicer to read, as these otherwise show up as "«unnamed»". The extra include is needed as naming means getters are generated, and the getters use the LLVM types.
1 parent b9dae2f commit 2e4161d

File tree

3 files changed

+27
-26
lines changed

3 files changed

+27
-26
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1718
#include "mlir/Dialect/SCF/IR/SCF.h"
1819
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1920
#include "mlir/IR/BuiltinTypes.h"

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
3838

3939
// Zero
4040
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
41-
Arguments<(ins Arg<I32, "Tile mask">)>;
41+
Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
4242

4343
// MOP's
4444
class ArmSME_IntrMopOverloadedOp<string mnemonic>
4545
: ArmSME_IntrOp<mnemonic, [4]>,
46-
Arguments<(ins Arg<I32, "Virtual tile ID">,
47-
Arg<MOPPredicate, "LHS predicate">,
48-
Arg<MOPPredicate, "RHS predicate">,
49-
Arg<MOPVector, "LHS vector operand">,
50-
Arg<MOPVector, "RHS vector operand">)>;
46+
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
47+
Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
48+
Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
49+
Arg<MOPVector, "LHS vector operand">:$lhs_vector,
50+
Arg<MOPVector, "RHS vector operand">:$rhs_vector)>;
5151

5252
def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
5353
def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
@@ -65,10 +65,10 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
6565
// Loads
6666
class ArmSME_IntrLoadOp<string mnemonic>
6767
: ArmSME_IntrOp<mnemonic>,
68-
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
69-
Arg<LLVM_AnyPointer, "Load address">,
70-
Arg<I32, "Virtual tile ID">,
71-
Arg<I32, "Tile slice">)>;
68+
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
69+
Arg<LLVM_AnyPointer, "Load address">:$load_address,
70+
Arg<I32, "Virtual tile ID">:$tile_id,
71+
Arg<I32, "Tile slice">:$tile_slice_index)>;
7272

7373
def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
7474
def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
@@ -84,10 +84,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
8484
// Stores
8585
class ArmSME_IntrStoreOp<string mnemonic>
8686
: ArmSME_IntrOp<mnemonic>,
87-
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
88-
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
89-
Arg<I32, "Virtual tile ID">,
90-
Arg<I32, "Tile slice">)>;
87+
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">:$predicate,
88+
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
89+
Arg<I32, "Virtual tile ID">:$tild_id,
90+
Arg<I32, "Tile slice">:$tile_slice_index)>;
9191

9292
def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
9393
def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
@@ -102,28 +102,28 @@ def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
102102

103103
def LLVM_aarch64_sme_str
104104
: ArmSME_IntrOp<"str">,
105-
Arguments<(ins Arg<I32, "Index">,
106-
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
105+
Arguments<(ins Arg<I32, "Index">:$index,
106+
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address)>;
107107

108108
// Vector to tile slice
109109
class LLVM_aarch64_sme_write<string direction>
110110
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
111-
[AllShapesMatch<["pg", "vector"]>]>,
112-
Arguments<(ins Arg<I32, "Virtual tile ID">,
113-
Arg<I32, "Tile slice">,
114-
Arg<SVEPredicate, "Vector predicate">:$pg,
111+
[AllShapesMatch<["predicate", "vector"]>]>,
112+
Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
113+
Arg<I32, "Tile slice">:$tile_slice_index,
114+
Arg<SVEPredicate, "Vector predicate">:$predicate,
115115
Arg<SVEVector, "Vector operand">:$vector)>;
116116

117117
// Tile slice to vector
118118
class LLVM_aarch64_sme_read<string direction>
119119
: ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
120-
[AllShapesMatch<["vector", "pg", "res"]>,
120+
[AllShapesMatch<["vector", "predicate", "res"]>,
121121
AllElementTypesMatch<["vector", "res"]>],
122122
/*numResults=*/1, /*overloadedResults=*/[0]>,
123123
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
124-
Arg<SVEPredicate, "Vector predicate">:$pg,
125-
Arg<I32, "Virtual tile ID">,
126-
Arg<I32, "Tile slice">)>;
124+
Arg<SVEPredicate, "Vector predicate">:$predicate,
125+
Arg<I32, "Virtual tile ID">:$tile_id,
126+
Arg<I32, "Tile slice">:$tile_slice_index)>;
127127

128128
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
129129
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;

mlir/test/Target/LLVMIR/arm-sme-invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
55
%nxv4i1 : vector<[4]xi1>,
66
%nxv16i8 : vector<[16]xi8>) {
77
%tile = llvm.mlir.constant(0 : index) : i32
8-
// expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
8+
// expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
99
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
1010
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
1111
llvm.return
@@ -17,7 +17,7 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
1717
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
1818
) -> vector<[3]xf32> {
1919
%tile = llvm.mlir.constant(0 : index) : i32
20-
// expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
20+
// expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
2121
%res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
2222
(vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
2323
llvm.return %res : vector<[3]xf32>

0 commit comments

Comments
 (0)