[mlir][x86vector] AVX Convert/Broadcast BF16 to F32 instructions - Fix #136830
Conversation
Added. @rengolin @adam-smnk, please do have a look.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: None (arun-thmn)
Changes: Quick fix for the PR: #135143, which failed building on the `amd` and `arm` bots. See the logs in the above PR for the errors.
Full diff: https://github.com/llvm/llvm-project/pull/136830.diff
9 Files Affected:
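For reference, a minimal before/after sketch of the lowering the new ops go through, based on the tests in this diff; `%ptr` here is an assumption standing in for the pointer extracted from the memref descriptor:

```mlir
// Before legalization: the op takes the memref directly.
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>

// After legalization (sketch): the memref has been lowered to a raw pointer
// %ptr and the op becomes a direct intrinsic call.
%1 = llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"(%ptr) : (!llvm.ptr) -> vector<8xf32>
```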
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 5be0d92db4630..126fa0e352656 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
}
@@ -404,8 +404,127 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}
}];
let extraClassDeclaration = [{
- SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneebf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vcvtneobf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert BF16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
+ DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+ let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
+ let description = [{
+ #### From the Intel Intrinsics Guide:
+
+ Convert scalar BF16 (16-bit) floating-point element stored at memory locations
+ starting at location `__A` to a single-precision (32-bit) floating-point,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ ```
+ }];
+ let arguments = (ins AnyMemRef:$a);
+ let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+ let assemblyFormat =
+ "$a attr-dict`:` type($a)`->` type($dst)";
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getIntrinsicName() {
+ std::string intr = "llvm.x86.vbcstnebf162ps";
+ VectorType vecType = getDst().getType();
+ unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+ unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+ intr += std::to_string(opBitWidth);
+ return intr;
+ }
+ }];
+
+ let extraClassDeclaration = [{
+ SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+ }];
+
+}
+
#endif // X86VECTOR_OPS
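Note: in each `getIntrinsicName()` above, the intrinsic suffix is the result vector's total bit width. For example, a `vector<8xf32>` result gives 8 × 32 = 256 and selects `llvm.x86.vcvtneebf162ps256`, while `vector<4xf32>` gives 4 × 32 = 128 and selects the `...128` variant.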
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
index 7bcf4c69b0a6c..308adfa5b9021 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h
@@ -14,6 +14,8 @@
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 98d5ca70b4a7d..5176f4a447b6e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getIntrinsicOperands",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
/*methodBody=*/"",
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
>,
diff --git a/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt
index d24617f037b13..5499d93d5f924 100644
--- a/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRX86VectorDialect
LINK_LIBS PUBLIC
MLIRIR
+ MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 5bb4dcfd60d83..f5e5070c74f8f 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,6 +31,26 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
+static SmallVector<Value>
+getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
+ RewriterBase &rewriter,
+ const LLVMTypeConverter &typeConverter) {
+ SmallVector<Value> operands;
+ auto opType = memrefVal.getType();
+
+ Type llvmStructType = typeConverter.convertType(opType);
+ Value llvmStruct =
+ rewriter
+ .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
+ .getResult(0);
+ MemRefDescriptor memRefDescriptor(llvmStruct);
+
+ Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
+ operands.push_back(ptr);
+
+ return operands;
+}
+
LogicalResult x86vector::MaskCompressOp::verify() {
if (getSrc() && getConstantSrc())
return emitError("cannot use both src and constant_src");
@@ -45,8 +65,8 @@ LogicalResult x86vector::MaskCompressOp::verify() {
return success();
}
-SmallVector<Value>
-x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
+SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
auto loc = getLoc();
auto opType = getA().getType();
@@ -64,7 +84,8 @@ x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
}
SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
+x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
+ const LLVMTypeConverter &typeConverter) {
SmallVector<Value> operands(getOperands());
// Dot product of all elements, broadcasted to all elements.
Value scale =
@@ -74,5 +95,22 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
return operands;
}
+SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
+ RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+ return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index c0c7f61f55f88..d2297554a1012 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
PatternRewriter &rewriter) const override {
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
- op.getIntrinsicOperands(rewriter), typeConverter,
- rewriter);
+ op.getIntrinsicOperands(rewriter, typeConverter),
+ typeConverter, rewriter);
}
private:
@@ -114,7 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
void mlir::configureX86VectorLegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
- Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, RsqrtOp,
- DotOp>();
+ target.addIllegalOp<
+ MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
+ CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
+ CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
}
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index df0be7bce83be..93b304c44de8e 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
+func.func @avxbf16_bsct_bf16_to_f32_packed_128(
+ %a: memref<1xbf16>) -> vector<4xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
+func.func @avxbf16_bsct_bf16_to_f32_packed_256(
+ %a: memref<1xbf16>) -> vector<8xf32>
+{
+ // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 0d00448c63da8..b783cc869b981 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+ // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
+func.func @avxbf16_bcst_bf16_to_f32_128(
+ %a: memref<1xbf16>) -> vector<4xf32>
+{
+ // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
+func.func @avxbf16_bcst_bf16_to_f32_256(
+ %a: memref<1xbf16>) -> vector<8xf32>
+{
+ // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+ // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir
index 85dad36334b1d..a8bc180d1d0ac 100644
--- a/mlir/test/Target/LLVMIR/x86vector.mlir
+++ b/mlir/test/Target/LLVMIR/x86vector.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm \
+// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm -reconcile-unrealized-casts \
// RUN: | mlir-translate --mlir-to-llvmir \
// RUN: | FileCheck %s
@@ -109,6 +109,60 @@ func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
return %0 : vector<16xbf16>
}
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneebf162ps128
+func.func @LLVM_x86_avxbf16_vcvtneebf162ps128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneebf162ps128(
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneebf162ps256
+func.func @LLVM_x86_avxbf16_vcvtneebf162ps256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneebf162ps256(
+ %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vcvtneobf162ps128
+func.func @LLVM_x86_avxbf16_vcvtneobf162ps128(
+ %a: memref<8xbf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vcvtneobf162ps128(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vcvtneobf162ps256
+func.func @LLVM_x86_avxbf16_vcvtneobf162ps256(
+ %a: memref<16xbf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vcvtneobf162ps256(
+ %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: define <4 x float> @LLVM_x86_avxbf16_vbcstnebf162ps128
+func.func @LLVM_x86_avxbf16_vbcstnebf162ps128(
+ %a: memref<1xbf16>) -> vector<4xf32>
+{
+ // CHECK: call <4 x float> @llvm.x86.vbcstnebf162ps128(
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: define <8 x float> @LLVM_x86_avxbf16_vbcstnebf162ps256
+func.func @LLVM_x86_avxbf16_vbcstnebf162ps256(
+ %a: memref<1xbf16>) -> vector<8xf32>
+{
+ // CHECK: call <8 x float> @llvm.x86.vbcstnebf162ps256(
+ %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
{
Hi @jplehr, can you please check and let me know whether the builds are succeeding on your bots?
The extra linked lib should do the trick - LGTM otherwise.
Let's wait first to double check if there's potentially anything else missing.
I'm currently on mobile and cannot monitor, but I'm confident that the additional library will address the issue.
I think we can land this and see what happens.
Adds AVX broadcast and conversion from F16 to packed F32 (similar to PR: #136830); a hedged sketch of the analogous ops follows the list. The instructions that are added:
- VBCSTNESH2PS
- VCVTNEEPH2PS
- VCVTNEOPH2PS
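A sketch of what usage of the F16 counterparts might look like, assuming they mirror the BF16 op syntax above; the op names are assumptions by analogy, not confirmed by this page:

```mlir
// Hypothetical F16 counterparts (names assumed by analogy with the BF16 ops).
%0 = x86vector.avx.cvt.packed.even.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
%1 = x86vector.avx.cvt.packed.odd.indexed.f16_to_f32 %a : memref<16xf16> -> vector<8xf32>
%2 = x86vector.avx.bcst.f16_to_f32.packed %b : memref<1xf16> -> vector<8xf32>
```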
Quick fix for the PR: #135143, which failed building on the `amd` and `arm` bots. See the logs in the above PR for the errors.