Skip to content

Commit cb3a394

Browse files
authored
[mlir][ArmSME] Add tile slice to vector intrinsics (#66910)
Add support for following vector to tile (MOVA) intrinsics to ArmSME dialect: ``` llvm.aarch64.sme.read.vert llvm.aarch64.sme.read.horiz ``` This also slightly updates ArmSME_IntrOp to support return values.
1 parent 01d3045 commit cb3a394

File tree

4 files changed

+141
-4
lines changed

4 files changed

+141
-4
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
469469
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
470470

471471
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
472-
list<Trait> traits = []>
472+
list<Trait> traits = [], int numResults = 0,
473+
list<int> overloadedResults = []>
473474
: LLVM_IntrOpBase<
474475
/*Dialect dialect=*/ArmSME_Dialect,
475476
/*string opName=*/"intr." # mnemonic,
476477
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
477-
/*list<int> overloadedResults=*/[],
478+
/*list<int> overloadedResults=*/overloadedResults,
478479
/*list<int> overloadedOperands=*/overloadedOperands,
479480
/*list<Trait> traits=*/traits,
480-
/*int numResults=*/0>;
481+
/*int numResults=*/numResults>;
481482

482483
// Zero
483484
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
548549
Arguments<(ins Arg<I32, "Index">,
549550
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
550551

551-
// Vector to tile
552+
// Vector to tile slice
552553
class LLVM_aarch64_sme_write<string direction>
553554
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
554555
[AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,23 @@ class LLVM_aarch64_sme_write<string direction>
557558
Arg<SVEPredicate, "Vector predicate">:$pg,
558559
Arg<SVEVector, "Vector operand">:$vector)>;
559560

561+
// Tile slice to vector
562+
class LLVM_aarch64_sme_read<string direction>
563+
: ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
564+
[AllShapesMatch<["vector", "pg", "res"]>,
565+
AllElementTypesMatch<["vector", "res"]>],
566+
/*numResults=*/1, /*overloadedResults=*/[0]>,
567+
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
568+
Arg<SVEPredicate, "Vector predicate">:$pg,
569+
Arg<I32, "Virtual tile ID">,
570+
Arg<I32, "Tile slice">)>;
571+
560572
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
561573
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
562574

575+
def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
576+
def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
577+
563578
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
564579
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
565580

mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1414
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15+
#include "mlir/IR/TypeUtilities.h"
1516

1617
using namespace mlir;
1718
using namespace mlir::arm_sme;

mlir/test/Target/LLVMIR/arm-sme-invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,27 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
1010
(i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
1111
llvm.return
1212
}
13+
14+
// -----
15+
16+
llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
17+
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
18+
) -> vector<[3]xf32> {
19+
%tile = llvm.mlir.constant(0 : index) : i32
20+
// expected-error @+1 {{failed to verify that all of {vector, pg, res} have same shape}}
21+
%res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
22+
(vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
23+
llvm.return %res : vector<[3]xf32>
24+
}
25+
26+
// -----
27+
28+
llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
29+
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
30+
) -> vector<[3]xi32> {
31+
%tile = llvm.mlir.constant(0 : index) : i32
32+
// expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
33+
%res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
34+
(vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
35+
llvm.return %res : vector<[4]xi32>
36+
}

mlir/test/Target/LLVMIR/arm-sme.mlir

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,100 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
334334
(i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
335335
llvm.return
336336
}
337+
338+
// -----
339+
340+
341+
llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
342+
%nxv16i1 : vector<[16]xi1>,
343+
%nxv8i1 : vector<[8]xi1>,
344+
%nxv4i1 : vector<[4]xi1>,
345+
%nxv2i1 : vector<[2]xi1>,
346+
%nxv1i1 : vector<[1]xi1>,
347+
%nxv16i8 : vector<[16]xi8>,
348+
%nxv8i16 : vector<[8]xi16>,
349+
%nxv4i32 : vector<[4]xi32>,
350+
%nxv2i64 : vector<[2]xi64>,
351+
%nxv1i128 : vector<[1]xi128>,
352+
%nxv8f16 : vector<[8]xf16>,
353+
%nxv8bf16 : vector<[8]xbf16>,
354+
%nxv4f32 : vector<[4]xf32>,
355+
%nxv2f64 : vector<[2]xf64>) {
356+
%tile = llvm.mlir.constant(0 : index) : i32
357+
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
358+
%res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
359+
: (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
360+
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
361+
%res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
362+
: (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
363+
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
364+
%res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
365+
: (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
366+
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
367+
%res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
368+
: (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
369+
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
370+
%res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
371+
: (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
372+
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
373+
%res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
374+
: (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
375+
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
376+
%res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
377+
: (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
378+
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
379+
%res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
380+
: (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
381+
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
382+
%res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
383+
: (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
384+
llvm.return
385+
}
386+
387+
// -----
388+
389+
llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
390+
%nxv16i1 : vector<[16]xi1>,
391+
%nxv8i1 : vector<[8]xi1>,
392+
%nxv4i1 : vector<[4]xi1>,
393+
%nxv2i1 : vector<[2]xi1>,
394+
%nxv1i1 : vector<[1]xi1>,
395+
%nxv16i8 : vector<[16]xi8>,
396+
%nxv8i16 : vector<[8]xi16>,
397+
%nxv4i32 : vector<[4]xi32>,
398+
%nxv2i64 : vector<[2]xi64>,
399+
%nxv1i128 : vector<[1]xi128>,
400+
%nxv8f16 : vector<[8]xf16>,
401+
%nxv8bf16 : vector<[8]xbf16>,
402+
%nxv4f32 : vector<[4]xf32>,
403+
%nxv2f64 : vector<[2]xf64>) {
404+
%tile = llvm.mlir.constant(0 : index) : i32
405+
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
406+
%res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
407+
: (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
408+
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
409+
%res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
410+
: (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
411+
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
412+
%res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
413+
: (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
414+
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
415+
%res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
416+
: (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
417+
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
418+
%res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
419+
: (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
420+
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
421+
%res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
422+
: (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
423+
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
424+
%res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
425+
: (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
426+
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
427+
%res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
428+
: (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
429+
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
430+
%res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
431+
: (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
432+
llvm.return
433+
}

0 commit comments

Comments
 (0)