Skip to content

Commit 87782b2

Browse files
authored
[mlir][x86vector] AVX512-BF16 Dot op (#124800)
Adds an AVX512-BF16 dot-product operation and defines its lowering to LLVM intrinsics. The AVX512 intrinsic operation definition is extended with an optional `extension` field that specifies the LLVM mnemonic suffix to use, e.g., `"bf16"` for `x86_avx512bf16_` intrinsics.
1 parent 9534d27 commit 87782b2

File tree

6 files changed

+249
-6
lines changed

6 files changed

+249
-6
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,20 @@ class AVX512_Op<string mnemonic, list<Trait> traits = []> :
3535
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
3636

3737
// Intrinsic operation used during lowering to LLVM IR.
38-
class AVX512_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
38+
class AVX512_IntrOp<string mnemonic, int numResults,
39+
list<Trait> traits = [],
40+
string extension = ""> :
3941
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
40-
"x86_avx512_" # !subst(".", "_", mnemonic),
42+
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
4143
[], [], traits, numResults>;
4244

4345
// Defined by first result overload. May have to be extended for other
4446
// instructions in the future.
4547
class AVX512_IntrOverloadedOp<string mnemonic,
46-
list<Trait> traits = []> :
48+
list<Trait> traits = [],
49+
string extension = ""> :
4750
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
48-
"x86_avx512_" # !subst(".", "_", mnemonic),
51+
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
4952
/*list<int> overloadedResults=*/[0],
5053
/*list<int> overloadedOperands=*/[],
5154
traits, /*numResults=*/1>;
@@ -271,6 +274,73 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
271274
VectorOfLengthAndType<[8], [I64]>:$b);
272275
}
273276

277+
//----------------------------------------------------------------------------//
278+
// Dot BF16
279+
//----------------------------------------------------------------------------//
280+
281+
def DotBF16Op : AVX512_Op<"dot", [Pure,
282+
AllTypesMatch<["a", "b"]>,
283+
AllTypesMatch<["src", "dst"]>,
284+
TypesMatchWith<"`a` has twice an many elements as `src`",
285+
"src", "a",
286+
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
287+
"BFloat16Type::get($_self.getContext()))">]> {
288+
let summary = "Dot BF16 op";
289+
let description = [{
290+
The `dot` op is an AVX512-BF16 specific op that can lower to the proper
291+
LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
292+
vectors it is applied to.
293+
294+
#### From the Intel Intrinsics Guide:
295+
296+
Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
297+
accumulating the intermediate single-precision (32-bit) floating-point
298+
elements with elements in `src`, and store the results in `dst`.
299+
300+
Example:
301+
```mlir
302+
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
303+
```
304+
}];
305+
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
306+
VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
307+
VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
308+
);
309+
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
310+
let assemblyFormat =
311+
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
312+
}
313+
314+
def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
315+
AllTypesMatch<["a", "b"]>,
316+
AllTypesMatch<["src", "res"]>],
317+
/*extension=*/"bf16"> {
318+
let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
319+
VectorOfLengthAndType<[8], [BF16]>:$a,
320+
VectorOfLengthAndType<[8], [BF16]>:$b);
321+
let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
322+
}
323+
324+
def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
325+
AllTypesMatch<["a", "b"]>,
326+
AllTypesMatch<["src", "res"]>],
327+
/*extension=*/"bf16"> {
328+
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
329+
VectorOfLengthAndType<[16], [BF16]>:$a,
330+
VectorOfLengthAndType<[16], [BF16]>:$b);
331+
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
332+
}
333+
334+
def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
335+
AllTypesMatch<["a", "b"]>,
336+
AllTypesMatch<["src", "res"]>],
337+
/*extension=*/"bf16"> {
338+
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
339+
VectorOfLengthAndType<[32], [BF16]>:$a,
340+
VectorOfLengthAndType<[32], [BF16]>:$b);
341+
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
342+
}
343+
274344
//===----------------------------------------------------------------------===//
275345
// AVX op definitions
276346
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,47 @@ struct MaskCompressOpConversion
9090
}
9191
};
9292

93+
struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
94+
using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
95+
96+
LogicalResult
97+
matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
98+
ConversionPatternRewriter &rewriter) const override {
99+
auto typeA = dyn_cast<VectorType>(op.getA().getType());
100+
unsigned elemBitWidth = typeA.getElementTypeBitWidth();
101+
unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
102+
103+
auto opType = adaptor.getSrc().getType();
104+
auto opSrc = adaptor.getSrc();
105+
auto opA = adaptor.getA();
106+
auto opB = adaptor.getB();
107+
108+
switch (opBitWidth) {
109+
case 128: {
110+
rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
111+
opB);
112+
break;
113+
}
114+
case 256: {
115+
rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
116+
opB);
117+
break;
118+
}
119+
case 512: {
120+
rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
121+
opB);
122+
break;
123+
}
124+
default: {
125+
return rewriter.notifyMatchFailure(op,
126+
"unsupported AVX512-BF16 dot variant");
127+
}
128+
}
129+
130+
return success();
131+
}
132+
};
133+
93134
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
94135
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
95136

@@ -161,15 +202,19 @@ using Registry = RegistryImpl<
161202
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
162203
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
163204
Registry::registerPatterns(converter, patterns);
164-
patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165-
converter);
205+
patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
206+
DotOpConversion>(converter);
166207
}
167208

168209
void mlir::configureX86VectorLegalizeForExportTarget(
169210
LLVMConversionTarget &target) {
170211
Registry::configureTarget(target);
171212
target.addLegalOp<MaskCompressIntrOp>();
172213
target.addIllegalOp<MaskCompressOp>();
214+
target.addLegalOp<DotBF16Ps128IntrOp>();
215+
target.addLegalOp<DotBF16Ps256IntrOp>();
216+
target.addLegalOp<DotBF16Ps512IntrOp>();
217+
target.addIllegalOp<DotBF16Op>();
173218
target.addLegalOp<RsqrtIntrOp>();
174219
target.addIllegalOp<RsqrtOp>();
175220
target.addLegalOp<DotIntrOp>();
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: mlir-opt %s \
2+
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
3+
// RUN: -reconcile-unrealized-casts | \
4+
// RUN: mlir-translate --mlir-to-llvmir | \
5+
// RUN: llc -mcpu=sapphirerapids | \
6+
// RUN: FileCheck %s
7+
8+
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
9+
%b: vector<8xbf16>) -> vector<4xf32> {
10+
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
11+
return %0 : vector<4xf32>
12+
}
13+
// CHECK-LABEL: avx512bf16_dot_128:
14+
// CHECK: vdpbf16ps{{.*}}%xmm
15+
16+
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
17+
%b: vector<16xbf16>) -> vector<8xf32> {
18+
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
19+
return %0 : vector<8xf32>
20+
}
21+
// CHECK-LABEL: avx512bf16_dot_256:
22+
// CHECK: vdpbf16ps{{.*}}%ymm
23+
24+
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
25+
%b: vector<32xbf16>) -> vector<16xf32> {
26+
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
27+
return %0 : vector<16xf32>
28+
}
29+
// CHECK-LABEL: avx512bf16_dot_512:
30+
// CHECK: vdpbf16ps{{.*}}%zmm

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,33 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
4343
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
4444
}
4545

46+
// CHECK-LABEL: func @avx512bf16_dot_128
47+
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
48+
%b: vector<8xbf16>) -> (vector<4xf32>)
49+
{
50+
// CHECK: x86vector.avx512.intr.dpbf16ps.128
51+
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
52+
return %0 : vector<4xf32>
53+
}
54+
55+
// CHECK-LABEL: func @avx512bf16_dot_256
56+
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
57+
%b: vector<16xbf16>) -> (vector<8xf32>)
58+
{
59+
// CHECK: x86vector.avx512.intr.dpbf16ps.256
60+
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
61+
return %0 : vector<8xf32>
62+
}
63+
64+
// CHECK-LABEL: func @avx512bf16_dot_512
65+
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
66+
%b: vector<32xbf16>) -> (vector<16xf32>)
67+
{
68+
// CHECK: x86vector.avx512.intr.dpbf16ps.512
69+
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
70+
return %0 : vector<16xf32>
71+
}
72+
4673
// CHECK-LABEL: func @avx_rsqrt
4774
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
4875
{

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,33 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
4747
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
4848
}
4949

50+
// CHECK-LABEL: func @avx512bf16_dot_128
51+
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
52+
%b: vector<8xbf16>) -> (vector<4xf32>)
53+
{
54+
// CHECK: x86vector.avx512.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
55+
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
56+
return %0 : vector<4xf32>
57+
}
58+
59+
// CHECK-LABEL: func @avx512bf16_dot_256
60+
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
61+
%b: vector<16xbf16>) -> (vector<8xf32>)
62+
{
63+
// CHECK: x86vector.avx512.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
64+
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
65+
return %0 : vector<8xf32>
66+
}
67+
68+
// CHECK-LABEL: func @avx512bf16_dot_512
69+
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
70+
%b: vector<32xbf16>) -> (vector<16xf32>)
71+
{
72+
// CHECK: x86vector.avx512.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
73+
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
74+
return %0 : vector<16xf32>
75+
}
76+
5077
// CHECK-LABEL: func @avx_rsqrt
5178
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
5279
{

mlir/test/Target/LLVMIR/x86vector.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,54 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
6060
llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
6161
}
6262

63+
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
64+
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
65+
%arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
66+
) -> vector<4xf32>
67+
{
68+
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
69+
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
70+
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
71+
llvm.return %0 : vector<4xf32>
72+
}
73+
74+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
75+
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
76+
%arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
77+
) -> vector<8xf32>
78+
{
79+
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
80+
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
81+
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
82+
llvm.return %0 : vector<8xf32>
83+
}
84+
85+
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
86+
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
87+
%arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
88+
) -> vector<16xf32>
89+
{
90+
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
91+
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
92+
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
93+
llvm.return %0 : vector<16xf32>
94+
}
95+
6396
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
6497
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
6598
{
6699
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
67100
%0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>)
68101
llvm.return %0 : vector<8xf32>
69102
}
103+
104+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
105+
llvm.func @LLVM_x86_avx_dp_ps_256(
106+
%arg0: vector<8xf32>, %arg1: vector<8xf32>
107+
) -> vector<8xf32>
108+
{
109+
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
110+
%0 = llvm.mlir.constant(-1 : i8) : i8
111+
%1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
112+
llvm.return %1 : vector<8xf32>
113+
}

0 commit comments

Comments
 (0)