@@ -95,6 +95,13 @@ def ScalableVectorType : ArmSVE_Type<"ScalableVector"> {
  }];
}

+ //===----------------------------------------------------------------------===//
+ // Additional LLVM type constraints
+ //===----------------------------------------------------------------------===//
+ def LLVMScalableVectorType :
+   Type<CPred<"$_self.isa<::mlir::LLVM::LLVMScalableVectorType>()">,
+        "LLVM dialect scalable vector type">;
+
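This constraint matches only values of the LLVM dialect's scalable vector type. Assuming the `!llvm.vec<? x n x elt>` syntax used by the LLVM dialect at this point, a type such as `!llvm.vec<? x 4 x f32>` would be accepted, while fixed-length vector types would be rejected.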
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -158,6 +165,52 @@ class ScalableIOp<string mnemonic, string op_description,
    "$src1 `,` $src2 attr-dict `:` type($src1)";
}

+ class ScalableMaskedFOp<string mnemonic, string op_description,
+                         list<OpTrait> traits = []> :
+   ArmSVE_Op<mnemonic, !listconcat(traits,
+                        [AllTypesMatch<["src1", "src2", "res"]>,
+                         TypesMatchWith<
+                           "mask has i1 element type and same shape as operands",
+                           "src1", "mask", "getI1SameShape($_self)">])> {
+   let summary = "masked " # op_description # " for scalable vectors of floats";
+   let description = [{
+     The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+     and two scalable vector operands, and performs floating point }] #
+     op_description # [{ on active lanes. Inactive lanes will keep the value of
+     the first operand.}];
+   let arguments = (ins
+     ScalableVectorOf<[I1]>:$mask,
+     ScalableVectorOf<[AnyFloat]>:$src1,
+     ScalableVectorOf<[AnyFloat]>:$src2
+   );
+   let results = (outs ScalableVectorOf<[AnyFloat]>:$res);
+   let assemblyFormat =
+     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+ }
+
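For illustration, an op instantiated from this class (such as `arm_sve.masked.addf`, defined further down) would be written roughly as follows; the `!arm_sve.vector` type syntax here is assumed from the dialect's existing scalable vector type rather than taken from this patch:

    %res = arm_sve.masked.addf %mask, %a, %b
             : !arm_sve.vector<4xi1>, !arm_sve.vector<4xf32>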
+ class ScalableMaskedIOp<string mnemonic, string op_description,
+                         list<OpTrait> traits = []> :
+   ArmSVE_Op<mnemonic, !listconcat(traits,
+                        [AllTypesMatch<["src1", "src2", "res"]>,
+                         TypesMatchWith<
+                           "mask has i1 element type and same shape as operands",
+                           "src1", "mask", "getI1SameShape($_self)">])> {
+   let summary = "masked " # op_description # " for scalable vectors of integers";
+   let description = [{
+     The `arm_sve.}] # mnemonic # [{` operation takes one scalable vector mask
+     and two scalable vector operands, and performs integer }] #
+     op_description # [{ on active lanes. Inactive lanes will keep the value of
+     the first operand.}];
+   let arguments = (ins
+     ScalableVectorOf<[I1]>:$mask,
+     ScalableVectorOf<[I8, I16, I32, I64]>:$src1,
+     ScalableVectorOf<[I8, I16, I32, I64]>:$src2
+   );
+   let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$res);
+   let assemblyFormat =
+     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
+ }
+
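The integer variant follows the same assembly format; for example an `arm_sve.masked.muli` (defined below) would read, under the same assumption about the scalable vector type syntax:

    %res = arm_sve.masked.muli %mask, %a, %b
             : !arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>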
def SdotOp : ArmSVE_Op<"sdot",
               [NoSideEffect,
                AllTypesMatch<["src1", "src2"]>,
@@ -321,21 +374,94 @@ def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">;

def ScalableDivFOp : ScalableFOp<"divf", "division">;

+ def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
+                                              [Commutative]>;
+
+ def ScalableMaskedAddFOp : ScalableMaskedFOp<"masked.addf", "addition",
+                                              [Commutative]>;
+
+ def ScalableMaskedSubIOp : ScalableMaskedIOp<"masked.subi", "subtraction">;
+
+ def ScalableMaskedSubFOp : ScalableMaskedFOp<"masked.subf", "subtraction">;
+
+ def ScalableMaskedMulIOp : ScalableMaskedIOp<"masked.muli", "multiplication",
+                                              [Commutative]>;
+
+ def ScalableMaskedMulFOp : ScalableMaskedFOp<"masked.mulf", "multiplication",
+                                              [Commutative]>;
+
+ def ScalableMaskedSDivIOp : ScalableMaskedIOp<"masked.divi_signed",
+                                               "signed division">;
+
+ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
+                                               "unsigned division">;
+
+ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+
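As a sketch of the masked semantics described above (values and SSA names are hypothetical, and the `!arm_sve.vector` syntax is assumed): with %mask = [1, 0, 1, 0], %a = [1, 2, 3, 4] and %b = [10, 20, 30, 40],

    %r = arm_sve.masked.addi %mask, %a, %b
           : !arm_sve.vector<4xi1>, !arm_sve.vector<4xi32>

yields %r = [11, 2, 33, 4]: active lanes hold the sum, while inactive lanes keep the value of the first operand, %a.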

def UmmlaIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"ummla">,
-   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;

def SmmlaIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"smmla">,
-   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;

def SdotIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"sdot">,
-   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;

def UdotIntrOp :
  ArmSVE_IntrBinaryOverloadedOp<"udot">,
-   Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>;
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedAddIIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"add">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedAddFIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"fadd">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedMulIIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"mul">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedMulFIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"fmul">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedSubIIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"sub">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedSubFIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"fsub">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedSDivIIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"sdiv">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedUDivIIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"udiv">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
+
+ def ScalableMaskedDivFIntrOp :
+   ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
+   Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType,
+                  LLVMScalableVectorType)>;
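Not part of this patch, but for orientation: assuming ArmSVE_IntrBinaryOverloadedOp prefixes its mnemonic with `intr.` and the LLVM dialect's `!llvm.vec<? x n x elt>` scalable vector syntax of the time, the masked intrinsic ops above would appear (in generic form) roughly as:

    %0 = "arm_sve.intr.fadd"(%mask, %a, %b)
           : (!llvm.vec<? x 4 x i1>, !llvm.vec<? x 4 x f32>, !llvm.vec<? x 4 x f32>)
          -> !llvm.vec<? x 4 x f32>

each one presumably mapping to the corresponding predicated `llvm.aarch64.sve.*` intrinsic (here `llvm.aarch64.sve.fadd`).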

def VectorScaleIntrOp:
  ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;