Skip to content

[mlir][x86vector] AVX Convert/Broadcast BF16 to F32 instructions #135143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,110 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}


//----------------------------------------------------------------------------//
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
//----------------------------------------------------------------------------//

// Converts the even-indexed BF16 elements read through the pointer `$a` into a
// packed F32 vector with 4 or 8 lanes.
//
// The op loads from memory, so it carries a `MemRead` effect rather than
// `Pure`: a `Pure` op may be CSE'd, hoisted or speculated across stores to the
// same address, which is not valid for a load.
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [
  MemoryEffects<[MemRead]>,
  DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
  let description = [{
    #### From the Intel Intrinsics Guide:

    Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
    memory locations starting at location `__A` to packed single-precision
    (32-bit) floating-point elements, and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
    ```
  }];
  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name ends with the total bit width of the result
      // vector: number of lanes times the element bit width.
      VectorType vecType = getDst().getType();
      unsigned opBitWidth =
          vecType.getShape()[0] * vecType.getElementTypeBitWidth();
      return "llvm.x86.vcvtneebf162ps" + std::to_string(opBitWidth);
    }
  }];
}

// Converts the odd-indexed BF16 elements read through the pointer `$a` into a
// packed F32 vector with 4 or 8 lanes.
//
// The op loads from memory, so it carries a `MemRead` effect rather than
// `Pure`: a `Pure` op may be CSE'd, hoisted or speculated across stores to the
// same address, which is not valid for a load.
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [
  MemoryEffects<[MemRead]>,
  DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
  let description = [{
    #### From the Intel Intrinsics Guide:

    Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
    memory locations starting at location `__A` to packed single-precision
    (32-bit) floating-point elements, and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
    ```
  }];
  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name ends with the total bit width of the result
      // vector: number of lanes times the element bit width.
      VectorType vecType = getDst().getType();
      unsigned opBitWidth =
          vecType.getShape()[0] * vecType.getElementTypeBitWidth();
      return "llvm.x86.vcvtneobf162ps" + std::to_string(opBitWidth);
    }
  }];
}

//----------------------------------------------------------------------------//
// AVX: Convert BF16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

// Converts one BF16 scalar read through the pointer `$a` to F32 and broadcasts
// it into a packed F32 vector with 4 or 8 lanes.
//
// The op loads from memory, so it carries a `MemRead` effect rather than
// `Pure`: a `Pure` op may be CSE'd, hoisted or speculated across stores to the
// same address, which is not valid for a load.
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [
  MemoryEffects<[MemRead]>,
  DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
  let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
  let description = [{
    #### From the Intel Intrinsics Guide:

    Convert scalar BF16 (16-bit) floating-point element stored at memory locations
    starting at location `__A` to a single-precision (32-bit) floating-point,
    broadcast it to packed single-precision (32-bit) floating-point elements,
    and store the results in `dst`.

    Example:
    ```mlir
    %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
    ```
  }];
  let arguments = (ins LLVM_AnyPointer:$a);
  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
  let assemblyFormat =
    "$a attr-dict`:` type($a)`->` type($dst)";

  let extraClassDefinition = [{
    std::string $cppClass::getIntrinsicName() {
      // The intrinsic name ends with the total bit width of the result
      // vector: number of lanes times the element bit width.
      VectorType vecType = getDst().getType();
      unsigned opBitWidth =
          vecType.getShape()[0] * vecType.getElementTypeBitWidth();
      return "llvm.x86.vbcstnebf162ps" + std::to_string(opBitWidth);
    }
  }];
}

#endif // X86VECTOR_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(

/// Marks the x86vector ops listed below as illegal on `target`.
///
/// Each op is registered exactly once, in a single `addIllegalOp` call.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addIllegalOp<
      MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
      CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
}
22 changes: 22 additions & 0 deletions mlir/test/Dialect/X86Vector/bcst-avx-bf16-to-f32-packed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// REQUIRES: target=x86{{.*}}

// RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sierraforest | \
// RUN: FileCheck %s

// 128-bit broadcast: the result has four f32 lanes.
func.func @avxbf16_bcst_bf16_to_f32_packed_128(%src: !llvm.ptr) -> vector<4xf32> {
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}
// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_128:
// CHECK: vbcstnebf162ps{{.*}}%xmm

// 256-bit broadcast: the result has eight f32 lanes.
func.func @avxbf16_bcst_bf16_to_f32_packed_256(%src: !llvm.ptr) -> vector<8xf32> {
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}
// CHECK-LABEL: avxbf16_bcst_bf16_to_f32_packed_256:
// CHECK: vbcstnebf162ps{{.*}}%ymm
48 changes: 48 additions & 0 deletions mlir/test/Dialect/X86Vector/cvt-packed-avx-bf16-to-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// REQUIRES: target=x86{{.*}}

// RUN: mlir-opt %s \
// RUN: -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-translate --mlir-to-llvmir | \
// RUN: llc -mcpu=sierraforest | \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the value of checking assembly? Shouldn't this be tested in LLVM instead?

I appreciate the desire for more complete, e2e testing, but these things come at a cost and I'd rather we focus on the bare minimum, especially for things that are definitely tested in LLVM (i.e. lowering LLVM intrinsics to ASM).

Hopefully this makes sense :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed!

// RUN: FileCheck %s

// The memref's aligned base address is turned into an `!llvm.ptr` through an
// i64 (not i32) integer so that no address bits are truncated on 64-bit
// targets.
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
  %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
  %0 = arith.index_cast %intptr : index to i64
  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
  %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
  return %2 : vector<4xf32>
}
// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_128:
// CHECK: vcvtneebf162ps{{.*}}%xmm

// The memref's aligned base address is turned into an `!llvm.ptr` through an
// i64 (not i32) integer so that no address bits are truncated on 64-bit
// targets.
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
  %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
  %0 = arith.index_cast %intptr : index to i64
  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
  %2 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
  return %2 : vector<8xf32>
}
// CHECK-LABEL: avxbf16_cvt_packed_even_indexed_bf16_to_f32_256:
// CHECK: vcvtneebf162ps{{.*}}%ymm

// The memref's aligned base address is turned into an `!llvm.ptr` through an
// i64 (not i32) integer so that no address bits are truncated on 64-bit
// targets.
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%arg0: memref<8xbf16>) -> vector<4xf32> {
  %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8xbf16> -> index
  %0 = arith.index_cast %intptr : index to i64
  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
  %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<4xf32>
  return %2 : vector<4xf32>
}
// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128:
// CHECK: vcvtneobf162ps{{.*}}%xmm

// The memref's aligned base address is turned into an `!llvm.ptr` through an
// i64 (not i32) integer so that no address bits are truncated on 64-bit
// targets.
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%arg0: memref<16xbf16>) -> vector<8xf32> {
  %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<16xbf16> -> index
  %0 = arith.index_cast %intptr : index to i64
  %1 = llvm.inttoptr %0 : i64 to !llvm.ptr
  %2 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %1 : !llvm.ptr -> vector<8xf32>
  return %2 : vector<8xf32>
}
// CHECK-LABEL: avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256:
// CHECK: vcvtneobf162ps{{.*}}%ymm
54 changes: 54 additions & 0 deletions mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}

// Even-indexed conversion with a vector<4xf32> result: expects the 128-bit
// intrinsic name.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Even-indexed conversion with a vector<8xf32> result: expects the 256-bit
// intrinsic name.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Odd-indexed conversion with a vector<4xf32> result: expects the 128-bit
// intrinsic name.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Odd-indexed conversion with a vector<8xf32> result: expects the 256-bit
// intrinsic name.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Broadcast with a vector<4xf32> result: expects the 128-bit intrinsic name.
// The function is spelled `bcst`, matching the op mnemonic.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_packed_128
func.func @avxbf16_bcst_bf16_to_f32_packed_128(
  %a: !llvm.ptr) -> vector<4xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
  return %0 : vector<4xf32>
}

// Broadcast with a vector<8xf32> result: expects the 256-bit intrinsic name.
// The function is spelled `bcst`, matching the op mnemonic.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_packed_256
func.func @avxbf16_bcst_bf16_to_f32_packed_256(
  %a: !llvm.ptr) -> vector<8xf32>
{
  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/X86Vector/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}

// Round-trips the even-indexed conversion with a vector<4xf32> result.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Round-trips the even-indexed conversion with a vector<8xf32> result.
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Round-trips the odd-indexed conversion with a vector<4xf32> result.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Round-trips the odd-indexed conversion with a vector<8xf32> result.
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Round-trips the broadcast op with a vector<4xf32> result.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
func.func @avxbf16_bcst_bf16_to_f32_128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<4xf32>
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Round-trips the broadcast op with a vector<8xf32> result.
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
func.func @avxbf16_bcst_bf16_to_f32_256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
  // CHECK-SAME: !llvm.ptr -> vector<8xf32>
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Target/LLVMIR/x86vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,60 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}

// Even-indexed conversion, four f32 lanes: expects a call to the 128-bit
// intrinsic in the emitted LLVM IR.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Even-indexed conversion, eight f32 lanes: expects a call to the 256-bit
// intrinsic in the emitted LLVM IR.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256
func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
  %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Odd-indexed conversion, four f32 lanes: expects a call to the 128-bit
// intrinsic in the emitted LLVM IR.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128
func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Odd-indexed conversion, eight f32 lanes: expects a call to the 256-bit
// intrinsic in the emitted LLVM IR.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256
func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
  %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// Broadcast, four f32 lanes: expects a call to the 128-bit intrinsic in the
// emitted LLVM IR.
// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128
func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(%src: !llvm.ptr) -> vector<4xf32> {
  // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<4xf32>
  return %dst : vector<4xf32>
}

// Broadcast, eight f32 lanes: expects a call to the 256-bit intrinsic in the
// emitted LLVM IR.
// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256
func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(%src: !llvm.ptr) -> vector<8xf32> {
  // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
  %dst = x86vector.avx.bcst.bf16_to_f32.packed %src : !llvm.ptr -> vector<8xf32>
  return %dst : vector<8xf32>
}

// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
Expand Down