Skip to content

Commit 2b71df5

Browse files
authored
[mlir][x86vector] AVX512-BF16 Convert packed F32 to BF16 (#125685)
Adds AVX512-BF16 conversion from packed f32 elements to packed bf16 elements. The tests are slightly refactored to better follow each file's conventions.
1 parent 2fdb26d commit 2b71df5

File tree

6 files changed

+171
-11
lines changed

6 files changed

+171
-11
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,46 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
341341
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
342342
}
343343

344+
//----------------------------------------------------------------------------//
345+
// Convert packed F32 to packed BF16
346+
//----------------------------------------------------------------------------//
347+
348+
// Dialect-level conversion op. It accepts an 8- or 16-element f32 vector and
// yields a bf16 vector; `AllElementCountsMatch` requires `a` and `dst` to hold
// the same number of elements.
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
  AllElementCountsMatch<["a", "dst"]>]> {
  let summary = "Convert packed F32 to packed BF16 Data.";
  let description = [{
    The `cvt.packed.f32_to_bf16` op is an AVX512-BF16 specific op that can
    lower to the proper LLVMAVX512BF16 operation `llvm.cvtneps2bf16` depending
    on the width of MLIR vectors it is applied to.

    #### From the Intel Intrinsics Guide:

    Convert packed single-precision (32-bit) floating-point elements in `a` to
    packed BF16 (16-bit) floating-point elements, and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
    ```
  }];
  let arguments = (ins VectorOfLengthAndType<[8, 16], [F32]>:$a);
  let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
  let assemblyFormat =
    "$a attr-dict `:` type($a) `->` type($dst)";
}
371+
372+
// Intrinsic-level op for the `.256` variant: one 8 x f32 operand, one
// 8 x bf16 result.
def CvtNeF32ToBF16Ps256IntrOp
    : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure], /*extension=*/"bf16"> {
  let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
  let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
}
377+
378+
// Intrinsic-level op for the `.512` variant: one 16 x f32 operand, one
// 16 x bf16 result.
def CvtNeF32ToBF16Ps512IntrOp
    : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure], /*extension=*/"bf16"> {
  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
  let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
}
383+
344384
//===----------------------------------------------------------------------===//
345385
// AVX op definitions
346386
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,39 @@ struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
131131
}
132132
};
133133

134+
struct CvtPackedF32ToBF16Conversion
135+
: public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
136+
using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
137+
138+
LogicalResult
139+
matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
140+
ConversionPatternRewriter &rewriter) const override {
141+
auto typeA = dyn_cast<VectorType>(op.getA().getType());
142+
unsigned elemBitWidth = typeA.getElementTypeBitWidth();
143+
unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
144+
145+
auto opType = op.getDst().getType();
146+
auto opA = op.getA();
147+
148+
switch (opBitWidth) {
149+
case 256: {
150+
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
151+
break;
152+
}
153+
case 512: {
154+
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
155+
break;
156+
}
157+
default: {
158+
return rewriter.notifyMatchFailure(
159+
op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
160+
}
161+
}
162+
163+
return success();
164+
}
165+
};
166+
134167
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
135168
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
136169

@@ -202,8 +235,10 @@ using Registry = RegistryImpl<
202235
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
203236
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
204237
Registry::registerPatterns(converter, patterns);
205-
patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
206-
DotOpConversion>(converter);
238+
patterns
239+
.add<MaskCompressOpConversion, DotBF16OpConversion,
240+
CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
241+
converter);
207242
}
208243

209244
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -215,6 +250,9 @@ void mlir::configureX86VectorLegalizeForExportTarget(
215250
target.addLegalOp<DotBF16Ps256IntrOp>();
216251
target.addLegalOp<DotBF16Ps512IntrOp>();
217252
target.addIllegalOp<DotBF16Op>();
253+
target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
254+
target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
255+
target.addIllegalOp<CvtPackedF32ToBF16Op>();
218256
target.addLegalOp<RsqrtIntrOp>();
219257
target.addIllegalOp<RsqrtOp>();
220258
target.addLegalOp<DotIntrOp>();
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// REQUIRES: target=x86{{.*}}
2+
3+
// RUN: mlir-opt %s \
4+
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
5+
// RUN: -reconcile-unrealized-casts | \
6+
// RUN: mlir-translate --mlir-to-llvmir | \
7+
// RUN: llc -mcpu=sapphirerapids | \
8+
// RUN: FileCheck %s
9+
10+
// Converts an 8 x f32 vector to 8 x bf16 through the x86vector op.
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
    %src: vector<8xf32>) -> vector<8xbf16> {
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<8xf32> -> vector<8xbf16>
  return %dst : vector<8xbf16>
}
15+
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_256:
16+
// CHECK: vcvtneps2bf16{{.*}}%xmm
17+
18+
// Converts a 16 x f32 vector to 16 x bf16 through the x86vector op.
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
    %src: vector<16xf32>) -> vector<16xbf16> {
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<16xf32> -> vector<16xbf16>
  return %dst : vector<16xbf16>
}
23+
// CHECK-LABEL: avx512bf16_cvt_packed_f32_to_bf16_512:
24+
// CHECK: vcvtneps2bf16{{.*}}%ymm

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,24 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
7070
return %0 : vector<16xf32>
7171
}
7272

73+
// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
74+
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
    %src: vector<8xf32>) -> (vector<8xbf16>)
{
  // The 8-element form is expected to become the `.256` intrinsic op.
  // CHECK: x86vector.avx512.intr.cvtneps2bf16.256
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<8xf32> -> vector<8xbf16>
  return %dst : vector<8xbf16>
}
81+
82+
// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
83+
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
    %src: vector<16xf32>) -> (vector<16xbf16>)
{
  // The 16-element form is expected to become the `.512` intrinsic op.
  // CHECK: x86vector.avx512.intr.cvtneps2bf16.512
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<16xf32> -> vector<16xbf16>
  return %dst : vector<16xbf16>
}
90+
7391
// CHECK-LABEL: func @avx_rsqrt
7492
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
7593
{

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,26 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
7474
return %0 : vector<16xf32>
7575
}
7676

77+
// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_256
78+
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
    %src: vector<8xf32>) -> (vector<8xbf16>)
{
  // The op and its 8-element operand/result types must be printed back.
  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
  // CHECK-SAME: vector<8xf32> -> vector<8xbf16>
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<8xf32> -> vector<8xbf16>
  return %dst : vector<8xbf16>
}
86+
87+
// CHECK-LABEL: func @avx512bf16_cvt_packed_f32_to_bf16_512
88+
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
    %src: vector<16xf32>) -> (vector<16xbf16>)
{
  // The op and its 16-element operand/result types must be printed back.
  // CHECK: x86vector.avx512.cvt.packed.f32_to_bf16 {{.*}} :
  // CHECK-SAME: vector<16xf32> -> vector<16xbf16>
  %dst = x86vector.avx512.cvt.packed.f32_to_bf16 %src
    : vector<16xf32> -> vector<16xbf16>
  return %dst : vector<16xbf16>
}
96+
7797
// CHECK-LABEL: func @avx_rsqrt
7898
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
7999
{

mlir/test/Target/LLVMIR/x86vector.mlir

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,37 +62,57 @@ llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
6262

6363
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
6464
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
65-
%arg0: vector<4xf32>, %arg1: vector<8xbf16>, %arg2: vector<8xbf16>
65+
%src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
6666
) -> vector<4xf32>
6767
{
6868
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
69-
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%arg0, %arg1, %arg2)
69+
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
7070
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
7171
llvm.return %0 : vector<4xf32>
7272
}
7373

7474
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
7575
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
76-
%arg0: vector<8xf32>, %arg1: vector<16xbf16>, %arg2: vector<16xbf16>
76+
%src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
7777
) -> vector<8xf32>
7878
{
7979
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
80-
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%arg0, %arg1, %arg2)
80+
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
8181
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
8282
llvm.return %0 : vector<8xf32>
8383
}
8484

8585
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
8686
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
87-
%arg0: vector<16xf32>, %arg1: vector<32xbf16>, %arg2: vector<32xbf16>
87+
%src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
8888
) -> vector<16xf32>
8989
{
9090
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
91-
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%arg0, %arg1, %arg2)
91+
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
9292
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
9393
llvm.return %0 : vector<16xf32>
9494
}
9595

96+
// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
97+
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
    %src: vector<8xf32>) -> vector<8xbf16>
{
  // The `.256` intrinsic op must translate to a call of the matching LLVM
  // intrinsic.
  // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
  %res = "x86vector.avx512.intr.cvtneps2bf16.256"(%src)
    : (vector<8xf32>) -> vector<8xbf16>
  llvm.return %res : vector<8xbf16>
}
105+
106+
// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
107+
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
    %src: vector<16xf32>) -> vector<16xbf16>
{
  // The `.512` intrinsic op must translate to a call of the matching LLVM
  // intrinsic.
  // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
  %res = "x86vector.avx512.intr.cvtneps2bf16.512"(%src)
    : (vector<16xf32>) -> vector<16xbf16>
  llvm.return %res : vector<16xbf16>
}
115+
96116
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
97117
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
98118
{
@@ -103,11 +123,11 @@ llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
103123

104124
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
105125
llvm.func @LLVM_x86_avx_dp_ps_256(
106-
%arg0: vector<8xf32>, %arg1: vector<8xf32>
126+
%a: vector<8xf32>, %b: vector<8xf32>
107127
) -> vector<8xf32>
108128
{
109129
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
110-
%0 = llvm.mlir.constant(-1 : i8) : i8
111-
%1 = "x86vector.avx.intr.dp.ps.256"(%arg0, %arg1, %0) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
130+
%c = llvm.mlir.constant(-1 : i8) : i8
131+
%1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
112132
llvm.return %1 : vector<8xf32>
113133
}

0 commit comments

Comments
 (0)