Commit 7dcca62

5 files changed (+234, -3)

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 121 additions & 0 deletions
@@ -49,6 +49,12 @@ def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
 def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
     [16, 8, 4, 2, 1], [I1]>;
 
+// A constraint for a 1-D scalable vector of `length`.
+class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedContainerType<
+  elementTypes, And<[IsVectorOfShape<[length]>, IsVectorTypeWithAnyDimScalablePred]>,
+  "a 1-D scalable vector with length " # length,
+  "::mlir::VectorType">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -321,6 +327,121 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
   let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
+// Inputs valid for the multi-vector zips (not including the 128-bit element zipqs).
+def ZipInputVectorType : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
+def ZipX2Op : ArmSVE_Op<"zip.x2", [
+  Pure,
+  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
+> {
+  let summary = "Multi-vector two-way zip op";
+
+  let description = [{
+    This operation interleaves elements from two input SVE vectors, returning
+    two new SVE vectors (`resultV1` and `resultV2`), which contain the low and
+    high halves of the result respectively.
+
+    Example:
+    ```mlir
+    // sourceV1 = [ A1, A2, A3, ... An ]
+    // sourceV2 = [ B1, B2, B3, ... Bn ]
+    // (resultV1, resultV2) = [ A1, B1, A2, B2, A3, B3, ... An, Bn ]
+    %resultV1, %resultV2 = arm_sve.zip.x2 %sourceV1, %sourceV2 : vector<[16]xi8>
+    ```
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features).
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2), [{
+      build($_builder, $_state, v1.getType(), v1.getType(), v1, v2);
+    }]>];
+
+  let assemblyFormat = "$sourceV1 `,` $sourceV2 attr-dict `:` type($sourceV1)";
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
+def ZipX4Op : ArmSVE_Op<"zip.x4", [
+  Pure,
+  AllTypesMatch<[
+    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
+    "resultV1", "resultV2", "resultV3", "resultV4"]>]
+> {
+  let summary = "Multi-vector four-way zip op";
+
+  let description = [{
+    This operation interleaves elements from four input SVE vectors, returning
+    four new SVE vectors, each of which contains a quarter of the result. The
+    first quarter will be in `resultV1`, the second in `resultV2`, the third in
+    `resultV3`, and the fourth in `resultV4`.
+
+    Example:
+    ```mlir
+    // sourceV1 = [ A1, A2, ... An ]
+    // sourceV2 = [ B1, B2, ... Bn ]
+    // sourceV3 = [ C1, C2, ... Cn ]
+    // sourceV4 = [ D1, D2, ... Dn ]
+    // (resultV1, resultV2, resultV3, resultV4)
+    //   = [ A1, B1, C1, D1, A2, B2, C2, D2, ... An, Bn, Cn, Dn ]
+    %resultV1, %resultV2, %resultV3, %resultV4 = arm_sve.zip.x4
+      %sourceV1, %sourceV2, %sourceV3, %sourceV4 : vector<[16]xi8>
+    ```
+
+    **Warning:** The result of this op is undefined for 64-bit elements on
+    hardware with a vector length of less than 256 bits!
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features).
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2,
+                       ZipInputVectorType:$sourceV3,
+                       ZipInputVectorType:$sourceV4);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2,
+                      ZipInputVectorType:$resultV3,
+                      ZipInputVectorType:$resultV4);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2, "Value":$v3, "Value":$v4), [{
+      build($_builder, $_state,
+            v1.getType(), v1.getType(),
+            v1.getType(), v1.getType(),
+            v1, v2, v3, v4);
+    }]>];
+
+  let assemblyFormat = [{
+    $sourceV1 `,` $sourceV2 `,` $sourceV3 `,` $sourceV4 attr-dict
+      `:` type($sourceV1)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
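As a quick illustration of what the `ZipInputVectorType` constraint above admits: each listed shape pairs an element count with an element width so that the lanes fill one 128-bit SVE granule (2×64, 4×32, 8×16, 16×8), and the vector must be 1-D and scalable. A minimal sketch, where the `%a`-`%d` values are hypothetical and assumed to have the annotated types:

```mlir
// Accepted: shape and element width pair up to a 128-bit granule.
%r0, %r1 = arm_sve.zip.x2 %a, %b : vector<[4]xf32>   // 4 lanes x 32 bits
%s0, %s1 = arm_sve.zip.x2 %c, %d : vector<[16]xi8>   // 16 lanes x 8 bits
// Rejected by the verifier: fixed-length vectors such as vector<8xi16>,
// and off-size scalable shapes such as vector<[7]xi8> (see invalid.mlir).
```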

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 12 additions & 3 deletions
@@ -137,6 +137,9 @@ using ConvertToSvboolOpLowering =
 using ConvertFromSvboolOpLowering =
     SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
 
+using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
+using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -163,7 +166,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering,
                ConvertToSvboolOpLowering,
-               ConvertFromSvboolOpLowering>(converter);
+               ConvertFromSvboolOpLowering,
+               ZipX2OpLowering,
+               ZipX4OpLowering>(converter);
   // clang-format on
 }
 
@@ -184,7 +189,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ScalableMaskedUDivIIntrOp,
                     ScalableMaskedDivFIntrOp,
                     ConvertToSvboolIntrOp,
-                    ConvertFromSvboolIntrOp>();
+                    ConvertFromSvboolIntrOp,
+                    ZipX2IntrOp,
+                    ZipX4IntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
@@ -199,6 +206,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedUDivIOp,
                       ScalableMaskedDivFOp,
                       ConvertToSvboolOp,
-                      ConvertFromSvboolOp>();
+                      ConvertFromSvboolOp,
+                      ZipX2Op,
+                      ZipX4Op>();
   // clang-format on
 }
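To make the effect of these one-to-one patterns concrete, here is a minimal before/after sketch; the printed form of the intrinsic op is an approximation inferred from the CHECK lines in the tests below, and operand/result types pass through unchanged:

```mlir
// Before legalization:
%r0, %r1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
// After populateArmSVELegalizeForLLVMExportPatterns runs, the op is
// replaced one-to-one by the corresponding intrinsic op:
%r0, %r1 = "arm_sve.intr.zip.x2"(%a, %b)
  : (vector<[8]xi16>, vector<[8]xi16>) -> (vector<[8]xi16>, vector<[8]xi16>)
```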

mlir/test/Dialect/ArmSVE/invalid.mlir

Lines changed: 15 additions & 0 deletions
@@ -49,3 +49,18 @@ func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8
 }
 
 
+// -----
+
+func.func @arm_sve_zip_x2_bad_vector_type(%a : vector<[7]xi8>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[7]xi8>'}}
+  arm_sve.zip.x2 %a, %a : vector<[7]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[5]xf64>'}}
+  arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
+  return
+}

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 24 additions & 0 deletions
@@ -187,3 +187,27 @@ func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[
   // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
   return %mask : vector<3x[1]xi1>
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(%a: vector<[8]xi16>, %b: vector<[8]xi16>)
+  -> (vector<[8]xi16>, vector<[8]xi16>)
+{
+  // CHECK: arm_sve.intr.zip.x2
+  %0, %1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
+  return %0, %1 : vector<[8]xi16>, vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %a: vector<[16]xi8>,
+  %b: vector<[16]xi8>,
+  %c: vector<[16]xi8>,
+  %d: vector<[16]xi8>
+) -> (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+{
+  // CHECK: arm_sve.intr.zip.x4
+  %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
+  return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
+}
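The diff does not include the test files' RUN lines. For context, a lit test of this shape is typically driven through mlir-opt with the ArmSVE lowering enabled, along the lines of the sketch below; the exact pass list is an assumption and not part of this commit:

```mlir
// RUN: mlir-opt %s -split-input-file -convert-vector-to-llvm="enable-arm-sve" \
// RUN:   -reconcile-unrealized-casts | FileCheck %s
```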

mlir/test/Dialect/ArmSVE/roundtrip.mlir

Lines changed: 62 additions & 0 deletions
@@ -163,3 +163,65 @@ func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
 
   return
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xi64>
+  %a1, %b1 = arm_sve.zip.x2 %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xf64>
+  %a2, %b2 = arm_sve.zip.x2 %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xi32>
+  %a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xf32>
+  %a4, %b4 = arm_sve.zip.x2 %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xi16>
+  %a5, %b5 = arm_sve.zip.x2 %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xf16>
+  %a6, %b6 = arm_sve.zip.x2 %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7 = arm_sve.zip.x2 %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[16]xi8>
+  %a8, %b8 = arm_sve.zip.x2 %v8, %v8 : vector<[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xi64>
+  %a1, %b1, %c1, %d1 = arm_sve.zip.x4 %v1, %v1, %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xf64>
+  %a2, %b2, %c2, %d2 = arm_sve.zip.x4 %v2, %v2, %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xi32>
+  %a3, %b3, %c3, %d3 = arm_sve.zip.x4 %v3, %v3, %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xf32>
+  %a4, %b4, %c4, %d4 = arm_sve.zip.x4 %v4, %v4, %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xi16>
+  %a5, %b5, %c5, %d5 = arm_sve.zip.x4 %v5, %v5, %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xf16>
+  %a6, %b6, %c6, %d6 = arm_sve.zip.x4 %v6, %v6, %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7, %c7, %d7 = arm_sve.zip.x4 %v7, %v7, %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[16]xi8>
+  %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
+  return
+}
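Reading the roundtrip cases together with the op description, here is a concrete trace for one of the two-way zips at vscale = 1 (where vector<[4]xi32> holds exactly four lanes); the element values are hypothetical:

```mlir
// %v3 = [ 0, 1, 2, 3 ], zipped with itself:
// full interleaving = [ 0, 0, 1, 1, 2, 2, 3, 3 ]
//   %a3 = [ 0, 0, 1, 1 ]   // low half  (resultV1)
//   %b3 = [ 2, 2, 3, 3 ]   // high half (resultV2)
%a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
```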
