Skip to content

Commit 8ce23b8

Browse files
committed
[mlir][ArmSME] Add vector to tile intrinsics
Add support for following vector to tile (MOVA) intrinsics to ArmSME dialect: llvm.aarch64.sme.write.vert llvm.aarch64.sme.write.horiz Includes the definition of new type predicate 'ScalableVectorOfRankAndLengthAndType' in OpBase.td. Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D157004
1 parent ba818c4 commit 8ce23b8

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef ARMSME_OPS
1515
#define ARMSME_OPS
1616

17+
include "mlir/IR/OpBase.td"
1718
include "mlir/Interfaces/SideEffectInterfaces.td"
1819
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
1920
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -61,6 +62,12 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
6162
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
6263
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
6364

65+
def SVEVector : ScalableVectorOfRankAndLengthAndType<
66+
[1], [16, 8, 4, 2, 1], [I8, I16, I32, I64, I128, F16, BF16, F32, F64]>;
67+
68+
def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
69+
[1], [16, 8, 4, 2, 1], [I1]>;
70+
6471
// A type constraint that verifies the bitwidth of the scalar integer returned
6572
// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
6673
def TileElementWidthMatchesTileID : TypesMatchWith<
@@ -496,6 +503,18 @@ def LLVM_aarch64_sme_str
496503
Arguments<(ins Arg<I32, "Index">,
497504
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
498505

506+
// Vector to tile
507+
class LLVM_aarch64_sme_write<string direction>
508+
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
509+
[AllShapesMatch<["pg", "vector"]>]>,
510+
Arguments<(ins Arg<I32, "Virtual tile ID">,
511+
Arg<I32, "Tile slice">,
512+
Arg<SVEPredicate, "Vector predicate">:$pg,
513+
Arg<SVEVector, "Vector operand">:$vector)>;
514+
515+
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
516+
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
517+
499518
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
500519
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
501520

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,19 @@ class ScalableVectorOfLengthAndType<list<int> allowedLengths,
533533
ScalableVectorOfLength<allowedLengths>.summary,
534534
"::mlir::VectorType">;
535535

536+
// Any scalable vector where the rank is from the given `allowedRanks` list and
537+
// the number of elements is from the given `allowedLengths` list and the type
538+
// is from the given `allowedTypes` list
539+
class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
540+
list<int> allowedLengths,
541+
list<Type> allowedTypes> : AllOfType<
542+
[ScalableVectorOfRank<allowedRanks>, ScalableVectorOf<allowedTypes>,
543+
ScalableVectorOfLength<allowedLengths>],
544+
ScalableVectorOfRank<allowedRanks>.summary #
545+
ScalableVectorOf<allowedTypes>.summary #
546+
ScalableVectorOfLength<allowedLengths>.summary,
547+
"::mlir::VectorType">;
548+
536549
def AnyVector : VectorOf<[AnyType]>;
537550
// Temporary vector type clone that allows gradual transition to 0-D vectors.
538551
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
2+
3+
// Verify shape of predicate and vector must match
4+
llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
5+
%nxv4i1 : vector<[4]xi1>,
6+
%nxv16i8 : vector<[16]xi8>) {
7+
%tile = llvm.mlir.constant(0 : index) : i32
8+
// expected-error @+1 {{failed to verify that all of {pg, vector} have same shape}}
9+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
10+
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
11+
llvm.return
12+
}

mlir/test/Target/LLVMIR/arm-sme.mlir

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,101 @@ llvm.func @arm_sme_toggle_za() {
236236
"arm_sme.intr.za.disable"() : () -> ()
237237
llvm.return
238238
}
239+
240+
// -----
241+
242+
// CHECK-LABEL: @arm_sme_vector_to_tile_horiz
243+
llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
244+
%nxv16i1 : vector<[16]xi1>,
245+
%nxv8i1 : vector<[8]xi1>,
246+
%nxv4i1 : vector<[4]xi1>,
247+
%nxv2i1 : vector<[2]xi1>,
248+
%nxv1i1 : vector<[1]xi1>,
249+
%nxv16i8 : vector<[16]xi8>,
250+
%nxv8i16 : vector<[8]xi16>,
251+
%nxv4i32 : vector<[4]xi32>,
252+
%nxv2i64 : vector<[2]xi64>,
253+
%nxv1i128 : vector<[1]xi128>,
254+
%nxv8f16 : vector<[8]xf16>,
255+
%nxv8bf16 : vector<[8]xbf16>,
256+
%nxv4f32 : vector<[4]xf32>,
257+
%nxv2f64 : vector<[2]xf64>) {
258+
%tile = llvm.mlir.constant(0 : index) : i32
259+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
260+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
261+
(i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
262+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
263+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
264+
(i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
265+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
266+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
267+
(i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
268+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
269+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
270+
(i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
271+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
272+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
273+
(i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
274+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
275+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
276+
(i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
277+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
278+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
279+
(i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
280+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
281+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
282+
(i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
283+
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
284+
"arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
285+
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
286+
llvm.return
287+
}
288+
289+
// -----
290+
291+
// CHECK-LABEL: @arm_sme_vector_to_tile_vert
292+
llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
293+
%nxv16i1 : vector<[16]xi1>,
294+
%nxv8i1 : vector<[8]xi1>,
295+
%nxv4i1 : vector<[4]xi1>,
296+
%nxv2i1 : vector<[2]xi1>,
297+
%nxv1i1 : vector<[1]xi1>,
298+
%nxv16i8 : vector<[16]xi8>,
299+
%nxv8i16 : vector<[8]xi16>,
300+
%nxv4i32 : vector<[4]xi32>,
301+
%nxv2i64 : vector<[2]xi64>,
302+
%nxv1i128 : vector<[1]xi128>,
303+
%nxv8f16 : vector<[8]xf16>,
304+
%nxv8bf16 : vector<[8]xbf16>,
305+
%nxv4f32 : vector<[4]xf32>,
306+
%nxv2f64 : vector<[2]xf64>) {
307+
%tile = llvm.mlir.constant(0 : index) : i32
308+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
309+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
310+
(i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
311+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
312+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
313+
(i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
314+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
315+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
316+
(i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
317+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
318+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
319+
(i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
320+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
321+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
322+
(i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
323+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
324+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
325+
(i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
326+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
327+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
328+
(i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
329+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
330+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
331+
(i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
332+
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
333+
"arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
334+
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
335+
llvm.return
336+
}

0 commit comments

Comments
 (0)