@@ -34,25 +34,6 @@ def X86Vector_Dialect : Dialect {
34
34
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
35
35
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
36
36
37
- // Intrinsic operation used during lowering to LLVM IR.
38
- class AVX512_IntrOp<string mnemonic, int numResults,
39
- list<Trait> traits = [],
40
- string extension = ""> :
41
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
42
- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
43
- [], [], traits, numResults>;
44
-
45
- // Defined by first result overload. May have to be extended for other
46
- // instructions in the future.
47
- class AVX512_IntrOverloadedOp<string mnemonic,
48
- list<Trait> traits = [],
49
- string extension = ""> :
50
- LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
51
- !subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
52
- /*list<int> overloadedResults=*/[0],
53
- /*list<int> overloadedOperands=*/[],
54
- traits, /*numResults=*/1>;
55
-
56
37
//----------------------------------------------------------------------------//
57
38
// MaskCompressOp
58
39
//----------------------------------------------------------------------------//
@@ -91,21 +72,14 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
91
72
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
92
73
" `:` type($dst) (`,` type($src)^)?";
93
74
let hasVerifier = 1;
94
- }
95
75
96
- def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
97
- Pure,
98
- AllTypesMatch<["a", "src", "res"]>,
99
- TypesMatchWith<"`k` has the same number of bits as elements in `res`",
100
- "res", "k",
101
- "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
102
- "IntegerType::get($_self.getContext(), 1))">]> {
103
- let arguments = (ins VectorOfLengthAndType<[16, 8],
104
- [F32, I32, F64, I64]>:$a,
105
- VectorOfLengthAndType<[16, 8],
106
- [F32, I32, F64, I64]>:$src,
107
- VectorOfLengthAndType<[16, 8],
108
- [I1]>:$k);
76
+ let extraClassDeclaration = [{
77
+ /// Return LLVM intrinsic function name matching op variant.
78
+ std::string getIntrinsicName() {
79
+ // Overload is resolved later by intrinsic call lowering.
80
+ return "llvm.x86.avx512.mask.compress";
81
+ }
82
+ }];
109
83
}
110
84
111
85
//----------------------------------------------------------------------------//
@@ -142,26 +116,21 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
142
116
let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
143
117
let assemblyFormat =
144
118
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
145
- }
146
119
147
- def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
148
- Pure,
149
- AllTypesMatch<["src", "a", "res"]>]> {
150
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
151
- I32:$k,
152
- VectorOfLengthAndType<[16], [F32]>:$a,
153
- I16:$imm,
154
- LLVM_Type:$rounding);
155
- }
156
-
157
- def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
158
- Pure,
159
- AllTypesMatch<["src", "a", "res"]>]> {
160
- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
161
- I32:$k,
162
- VectorOfLengthAndType<[8], [F64]>:$a,
163
- I8:$imm,
164
- LLVM_Type:$rounding);
120
+ let extraClassDeclaration = [{
121
+ /// Return LLVM intrinsic function name matching op variant.
122
+ std::string getIntrinsicName() {
123
+ std::string intr = "llvm.x86.avx512.mask.rndscale";
124
+ VectorType vecType = getSrc().getType();
125
+ Type elemType = vecType.getElementType();
126
+ intr += ".";
127
+ intr += elemType.isF32() ? "ps" : "pd";
128
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
129
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
130
+ intr += "." + std::to_string(opBitWidth);
131
+ return intr;
132
+ }
133
+ }];
165
134
}
166
135
167
136
//----------------------------------------------------------------------------//
@@ -199,26 +168,21 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
199
168
// Fully specified by traits.
200
169
let assemblyFormat =
201
170
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
202
- }
203
-
204
- def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
205
- Pure,
206
- AllTypesMatch<["src", "a", "b", "res"]>]> {
207
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
208
- VectorOfLengthAndType<[16], [F32]>:$a,
209
- VectorOfLengthAndType<[16], [F32]>:$b,
210
- I16:$k,
211
- LLVM_Type:$rounding);
212
- }
213
171
214
- def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
215
- Pure,
216
- AllTypesMatch<["src", "a", "b", "res"]>]> {
217
- let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
218
- VectorOfLengthAndType<[8], [F64]>:$a,
219
- VectorOfLengthAndType<[8], [F64]>:$b,
220
- I8:$k,
221
- LLVM_Type:$rounding);
172
+ let extraClassDeclaration = [{
173
+ /// Return LLVM intrinsic function name matching op variant.
174
+ std::string getIntrinsicName() {
175
+ std::string intr = "llvm.x86.avx512.mask.scalef";
176
+ VectorType vecType = getSrc().getType();
177
+ Type elemType = vecType.getElementType();
178
+ intr += ".";
179
+ intr += elemType.isF32() ? "ps" : "pd";
180
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
181
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
182
+ intr += "." + std::to_string(opBitWidth);
183
+ return intr;
184
+ }
185
+ }];
222
186
}
223
187
224
188
//----------------------------------------------------------------------------//
@@ -260,18 +224,21 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
260
224
);
261
225
let assemblyFormat =
262
226
"$a `,` $b attr-dict `:` type($a)";
263
- }
264
-
265
- def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
266
- Pure]> {
267
- let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
268
- VectorOfLengthAndType<[16], [I32]>:$b);
269
- }
270
227
271
- def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
272
- Pure]> {
273
- let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
274
- VectorOfLengthAndType<[8], [I64]>:$b);
228
+ let extraClassDeclaration = [{
229
+ /// Return LLVM intrinsic function name matching op variant.
230
+ std::string getIntrinsicName() {
231
+ std::string intr = "llvm.x86.avx512.vp2intersect";
232
+ VectorType vecType = getA().getType();
233
+ Type elemType = vecType.getElementType();
234
+ intr += ".";
235
+ intr += elemType.isInteger(32) ? "d" : "q";
236
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
237
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
238
+ intr += "." + std::to_string(opBitWidth);
239
+ return intr;
240
+ }
241
+ }];
275
242
}
276
243
277
244
//----------------------------------------------------------------------------//
@@ -299,7 +266,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
299
266
300
267
Example:
301
268
```mlir
302
- %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
269
+ %dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303
270
```
304
271
}];
305
272
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -309,36 +276,18 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
309
276
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310
277
let assemblyFormat =
311
278
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312
- }
313
-
314
- def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315
- AllTypesMatch<["a", "b"]>,
316
- AllTypesMatch<["src", "res"]>],
317
- /*extension=*/"bf16"> {
318
- let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319
- VectorOfLengthAndType<[8], [BF16]>:$a,
320
- VectorOfLengthAndType<[8], [BF16]>:$b);
321
- let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322
- }
323
279
324
- def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325
- AllTypesMatch<["a", "b"]>,
326
- AllTypesMatch<["src", "res"]>],
327
- /*extension=*/"bf16"> {
328
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329
- VectorOfLengthAndType<[16], [BF16]>:$a,
330
- VectorOfLengthAndType<[16], [BF16]>:$b);
331
- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332
- }
333
-
334
- def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335
- AllTypesMatch<["a", "b"]>,
336
- AllTypesMatch<["src", "res"]>],
337
- /*extension=*/"bf16"> {
338
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339
- VectorOfLengthAndType<[32], [BF16]>:$a,
340
- VectorOfLengthAndType<[32], [BF16]>:$b);
341
- let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
280
+ let extraClassDeclaration = [{
281
+ /// Return LLVM intrinsic function name matching op variant.
282
+ std::string getIntrinsicName() {
283
+ std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
284
+ VectorType vecType = getSrc().getType();
285
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
286
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
287
+ intr += "." + std::to_string(opBitWidth);
288
+ return intr;
289
+ }
290
+ }];
342
291
}
343
292
344
293
//----------------------------------------------------------------------------//
@@ -367,18 +316,18 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
367
316
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
368
317
let assemblyFormat =
369
318
"$a attr-dict `:` type($a) `->` type($dst)";
370
- }
371
319
372
- def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
373
- /*extension=*/"bf16"> {
374
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
375
- let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
376
- }
377
-
378
- def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
379
- /*extension=*/"bf16"> {
380
- let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
381
- let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
320
+ let extraClassDeclaration = [{
321
+ /// Return LLVM intrinsic function name matching op variant.
322
+ std::string getIntrinsicName() {
323
+ std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
324
+ VectorType vecType = getA().getType();
325
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
326
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
327
+ intr += "." + std::to_string(opBitWidth);
328
+ return intr;
329
+ }
330
+ }];
382
331
}
383
332
384
333
//===----------------------------------------------------------------------===//
@@ -395,12 +344,6 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
395
344
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
396
345
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
397
346
398
- // Intrinsic operation used during lowering to LLVM IR.
399
- class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
400
- LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
401
- "x86_avx_" # !subst(".", "_", mnemonic),
402
- [], [], traits, numResults>;
403
-
404
347
//----------------------------------------------------------------------------//
405
348
// AVX Rsqrt
406
349
//----------------------------------------------------------------------------//
@@ -410,11 +353,13 @@ def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
410
353
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
411
354
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
412
355
let assemblyFormat = "$a attr-dict `:` type($a)";
413
- }
414
356
415
- def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
416
- SameOperandsAndResultType]> {
417
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
357
+ let extraClassDeclaration = [{
358
+ /// Return LLVM intrinsic function name matching op variant.
359
+ std::string getIntrinsicName() {
360
+ return "llvm.x86.avx.rsqrt.ps.256";
361
+ }
362
+ }];
418
363
}
419
364
420
365
//----------------------------------------------------------------------------//
@@ -443,13 +388,13 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
443
388
VectorOfLengthAndType<[8], [F32]>:$b);
444
389
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
445
390
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
446
- }
447
391
448
- def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
449
- AllTypesMatch<["a", "b", "res"]>]> {
450
- let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
451
- VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
452
- let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
392
+ let extraClassDeclaration = [{
393
+ /// Return LLVM intrinsic function name matching op variant.
394
+ std::string getIntrinsicName() {
395
+ return "llvm.x86.avx.dp.ps.256";
396
+ }
397
+ }];
453
398
}
454
399
455
400
#endif // X86VECTOR_OPS
0 commit comments