Skip to content

Commit bd2621b

Browse files
committed
[mlir][x86vector] Simplify intrinsic generation
Replaces separate x86vector named intrinsic operations with direct calls to LLVM intrinsic functions. This rework reduces the number of named ops leaving only high-level MLIR equivalents of whole intrinsic classes e.g., variants of AVX512 dot on BF16 inputs. Dialect conversion applies LLVM intrinsic name mangling further simplifying lowering logic. The separate conversion step translating x86vector intrinsics into LLVM IR is also eliminated. Instead, this step is now performed by the existing llvm dialect infrastructure.
1 parent 5c65a32 commit bd2621b

File tree

12 files changed

+228
-524
lines changed

12 files changed

+228
-524
lines changed
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
add_mlir_dialect(X86Vector x86vector)
22
add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
3-
4-
set(LLVM_TARGET_DEFINITIONS X86Vector.td)
5-
mlir_tablegen(X86VectorConversions.inc -gen-llvmir-conversions)
6-
add_public_tablegen_target(MLIRX86VectorConversionsIncGen)

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 84 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,6 @@ def X86Vector_Dialect : Dialect {
3434
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
3535
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
3636

37-
// Intrinsic operation used during lowering to LLVM IR.
38-
class AVX512_IntrOp<string mnemonic, int numResults,
39-
list<Trait> traits = [],
40-
string extension = ""> :
41-
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
42-
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
43-
[], [], traits, numResults>;
44-
45-
// Defined by first result overload. May have to be extended for other
46-
// instructions in the future.
47-
class AVX512_IntrOverloadedOp<string mnemonic,
48-
list<Trait> traits = [],
49-
string extension = ""> :
50-
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
51-
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
52-
/*list<int> overloadedResults=*/[0],
53-
/*list<int> overloadedOperands=*/[],
54-
traits, /*numResults=*/1>;
55-
5637
//----------------------------------------------------------------------------//
5738
// MaskCompressOp
5839
//----------------------------------------------------------------------------//
@@ -91,21 +72,14 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
9172
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
9273
" `:` type($dst) (`,` type($src)^)?";
9374
let hasVerifier = 1;
94-
}
9575

96-
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
97-
Pure,
98-
AllTypesMatch<["a", "src", "res"]>,
99-
TypesMatchWith<"`k` has the same number of bits as elements in `res`",
100-
"res", "k",
101-
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
102-
"IntegerType::get($_self.getContext(), 1))">]> {
103-
let arguments = (ins VectorOfLengthAndType<[16, 8],
104-
[F32, I32, F64, I64]>:$a,
105-
VectorOfLengthAndType<[16, 8],
106-
[F32, I32, F64, I64]>:$src,
107-
VectorOfLengthAndType<[16, 8],
108-
[I1]>:$k);
76+
let extraClassDeclaration = [{
77+
/// Return LLVM intrinsic function name matching op variant.
78+
std::string getIntrinsicName() {
79+
// Overload is resolved later by intrisic call lowering.
80+
return "llvm.x86.avx512.mask.compress";
81+
}
82+
}];
10983
}
11084

11185
//----------------------------------------------------------------------------//
@@ -142,26 +116,21 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
142116
let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
143117
let assemblyFormat =
144118
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
145-
}
146119

147-
def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
148-
Pure,
149-
AllTypesMatch<["src", "a", "res"]>]> {
150-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
151-
I32:$k,
152-
VectorOfLengthAndType<[16], [F32]>:$a,
153-
I16:$imm,
154-
LLVM_Type:$rounding);
155-
}
156-
157-
def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
158-
Pure,
159-
AllTypesMatch<["src", "a", "res"]>]> {
160-
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
161-
I32:$k,
162-
VectorOfLengthAndType<[8], [F64]>:$a,
163-
I8:$imm,
164-
LLVM_Type:$rounding);
120+
let extraClassDeclaration = [{
121+
/// Return LLVM intrinsic function name matching op variant.
122+
std::string getIntrinsicName() {
123+
std::string intr = "llvm.x86.avx512.mask.rndscale";
124+
VectorType vecType = getSrc().getType();
125+
Type elemType = vecType.getElementType();
126+
intr += ".";
127+
intr += elemType.isF32() ? "ps" : "pd";
128+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
129+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
130+
intr += "." + std::to_string(opBitWidth);
131+
return intr;
132+
}
133+
}];
165134
}
166135

167136
//----------------------------------------------------------------------------//
@@ -199,26 +168,21 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
199168
// Fully specified by traits.
200169
let assemblyFormat =
201170
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
202-
}
203-
204-
def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
205-
Pure,
206-
AllTypesMatch<["src", "a", "b", "res"]>]> {
207-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
208-
VectorOfLengthAndType<[16], [F32]>:$a,
209-
VectorOfLengthAndType<[16], [F32]>:$b,
210-
I16:$k,
211-
LLVM_Type:$rounding);
212-
}
213171

214-
def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
215-
Pure,
216-
AllTypesMatch<["src", "a", "b", "res"]>]> {
217-
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
218-
VectorOfLengthAndType<[8], [F64]>:$a,
219-
VectorOfLengthAndType<[8], [F64]>:$b,
220-
I8:$k,
221-
LLVM_Type:$rounding);
172+
let extraClassDeclaration = [{
173+
/// Return LLVM intrinsic function name matching op variant.
174+
std::string getIntrinsicName() {
175+
std::string intr = "llvm.x86.avx512.mask.scalef";
176+
VectorType vecType = getSrc().getType();
177+
Type elemType = vecType.getElementType();
178+
intr += ".";
179+
intr += elemType.isF32() ? "ps" : "pd";
180+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
181+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
182+
intr += "." + std::to_string(opBitWidth);
183+
return intr;
184+
}
185+
}];
222186
}
223187

224188
//----------------------------------------------------------------------------//
@@ -260,18 +224,21 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
260224
);
261225
let assemblyFormat =
262226
"$a `,` $b attr-dict `:` type($a)";
263-
}
264-
265-
def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
266-
Pure]> {
267-
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
268-
VectorOfLengthAndType<[16], [I32]>:$b);
269-
}
270227

271-
def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
272-
Pure]> {
273-
let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
274-
VectorOfLengthAndType<[8], [I64]>:$b);
228+
let extraClassDeclaration = [{
229+
/// Return LLVM intrinsic function name matching op variant.
230+
std::string getIntrinsicName() {
231+
std::string intr = "llvm.x86.avx512.vp2intersect";
232+
VectorType vecType = getA().getType();
233+
Type elemType = vecType.getElementType();
234+
intr += ".";
235+
intr += elemType.isInteger(32) ? "d" : "q";
236+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
237+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
238+
intr += "." + std::to_string(opBitWidth);
239+
return intr;
240+
}
241+
}];
275242
}
276243

277244
//----------------------------------------------------------------------------//
@@ -299,7 +266,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
299266

300267
Example:
301268
```mlir
302-
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
269+
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303270
```
304271
}];
305272
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
@@ -309,36 +276,18 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
309276
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310277
let assemblyFormat =
311278
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312-
}
313-
314-
def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315-
AllTypesMatch<["a", "b"]>,
316-
AllTypesMatch<["src", "res"]>],
317-
/*extension=*/"bf16"> {
318-
let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319-
VectorOfLengthAndType<[8], [BF16]>:$a,
320-
VectorOfLengthAndType<[8], [BF16]>:$b);
321-
let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322-
}
323279

324-
def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325-
AllTypesMatch<["a", "b"]>,
326-
AllTypesMatch<["src", "res"]>],
327-
/*extension=*/"bf16"> {
328-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329-
VectorOfLengthAndType<[16], [BF16]>:$a,
330-
VectorOfLengthAndType<[16], [BF16]>:$b);
331-
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332-
}
333-
334-
def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335-
AllTypesMatch<["a", "b"]>,
336-
AllTypesMatch<["src", "res"]>],
337-
/*extension=*/"bf16"> {
338-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339-
VectorOfLengthAndType<[32], [BF16]>:$a,
340-
VectorOfLengthAndType<[32], [BF16]>:$b);
341-
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
280+
let extraClassDeclaration = [{
281+
/// Return LLVM intrinsic function name matching op variant.
282+
std::string getIntrinsicName() {
283+
std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
284+
VectorType vecType = getSrc().getType();
285+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
286+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
287+
intr += "." + std::to_string(opBitWidth);
288+
return intr;
289+
}
290+
}];
342291
}
343292

344293
//----------------------------------------------------------------------------//
@@ -367,18 +316,18 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
367316
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
368317
let assemblyFormat =
369318
"$a attr-dict `:` type($a) `->` type($dst)";
370-
}
371319

372-
def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
373-
/*extension=*/"bf16"> {
374-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
375-
let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
376-
}
377-
378-
def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
379-
/*extension=*/"bf16"> {
380-
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
381-
let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
320+
let extraClassDeclaration = [{
321+
/// Return LLVM intrinsic function name matching op variant.
322+
std::string getIntrinsicName() {
323+
std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
324+
VectorType vecType = getA().getType();
325+
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
326+
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
327+
intr += "." + std::to_string(opBitWidth);
328+
return intr;
329+
}
330+
}];
382331
}
383332

384333
//===----------------------------------------------------------------------===//
@@ -395,12 +344,6 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
395344
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
396345
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
397346

398-
// Intrinsic operation used during lowering to LLVM IR.
399-
class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
400-
LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
401-
"x86_avx_" # !subst(".", "_", mnemonic),
402-
[], [], traits, numResults>;
403-
404347
//----------------------------------------------------------------------------//
405348
// AVX Rsqrt
406349
//----------------------------------------------------------------------------//
@@ -410,11 +353,13 @@ def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
410353
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
411354
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
412355
let assemblyFormat = "$a attr-dict `:` type($a)";
413-
}
414356

415-
def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
416-
SameOperandsAndResultType]> {
417-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
357+
let extraClassDeclaration = [{
358+
/// Return LLVM intrinsic function name matching op variant.
359+
std::string getIntrinsicName() {
360+
return "llvm.x86.avx.rsqrt.ps.256";
361+
}
362+
}];
418363
}
419364

420365
//----------------------------------------------------------------------------//
@@ -443,13 +388,13 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
443388
VectorOfLengthAndType<[8], [F32]>:$b);
444389
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
445390
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
446-
}
447391

448-
def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
449-
AllTypesMatch<["a", "b", "res"]>]> {
450-
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
451-
VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
452-
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
392+
let extraClassDeclaration = [{
393+
/// Return LLVM intrinsic function name matching op variant.
394+
std::string getIntrinsicName() {
395+
return "llvm.x86.avx.dp.ps.256";
396+
}
397+
}];
453398
}
454399

455400
#endif // X86VECTOR_OPS

mlir/include/mlir/Target/LLVMIR/Dialect/All.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
3030
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
3131
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
32-
#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
3332

3433
namespace mlir {
3534
class DialectRegistry;
@@ -50,7 +49,6 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
5049
registerROCDLDialectTranslation(registry);
5150
registerSPIRVDialectTranslation(registry);
5251
registerVCIXDialectTranslation(registry);
53-
registerX86VectorDialectTranslation(registry);
5452

5553
// Extension required for translating GPU offloading Ops.
5654
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);

mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h

Lines changed: 0 additions & 32 deletions
This file was deleted.

mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
22
AVXTranspose.cpp
33
LegalizeForLLVMExport.cpp
44

5-
DEPENDS
6-
MLIRX86VectorConversionsIncGen
7-
85
LINK_LIBS PUBLIC
96
MLIRArithDialect
107
MLIRX86VectorDialect

0 commit comments

Comments
 (0)