Skip to content

Commit 7e86afa

Browse files
authored
Revert "[mlir][x86vector] AVX Convert/Broadcast BF16 to F32 instructions" (#136781)
Reverts #135143. This broke multiple bots; see the PR for details.
1 parent 0797f70 commit 7e86afa

File tree

8 files changed

+12
-340
lines changed

8 files changed

+12
-340
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 2 additions & 121 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
8383
}
8484
}];
8585
let extraClassDeclaration = [{
86-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
86+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
8787
}];
8888
}
8989

@@ -404,127 +404,8 @@ def DotOp : AVX_LowOp<"dot", [Pure,
404404
}
405405
}];
406406
let extraClassDeclaration = [{
407-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
407+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
408408
}];
409409
}
410410

411-
412-
//----------------------------------------------------------------------------//
413-
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
414-
//----------------------------------------------------------------------------//
415-
416-
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
417-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
418-
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
419-
let description = [{
420-
#### From the Intel Intrinsics Guide:
421-
422-
Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
423-
memory locations starting at location `__A` to packed single-precision
424-
(32-bit) floating-point elements, and store the results in `dst`.
425-
426-
Example:
427-
```mlir
428-
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
429-
```
430-
}];
431-
let arguments = (ins AnyMemRef:$a);
432-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
433-
let assemblyFormat =
434-
"$a attr-dict`:` type($a)`->` type($dst)";
435-
436-
let extraClassDefinition = [{
437-
std::string $cppClass::getIntrinsicName() {
438-
std::string intr = "llvm.x86.vcvtneebf162ps";
439-
VectorType vecType = getDst().getType();
440-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
441-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
442-
intr += std::to_string(opBitWidth);
443-
return intr;
444-
}
445-
}];
446-
447-
let extraClassDeclaration = [{
448-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
449-
}];
450-
}
451-
452-
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
453-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
454-
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
455-
let description = [{
456-
#### From the Intel Intrinsics Guide:
457-
458-
Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
459-
memory locations starting at location `__A` to packed single-precision
460-
(32-bit) floating-point elements, and store the results in `dst`.
461-
462-
Example:
463-
```mlir
464-
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
465-
```
466-
}];
467-
let arguments = (ins AnyMemRef:$a);
468-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
469-
let assemblyFormat =
470-
"$a attr-dict`:` type($a)`->` type($dst)";
471-
472-
let extraClassDefinition = [{
473-
std::string $cppClass::getIntrinsicName() {
474-
std::string intr = "llvm.x86.vcvtneobf162ps";
475-
VectorType vecType = getDst().getType();
476-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
477-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
478-
intr += std::to_string(opBitWidth);
479-
return intr;
480-
}
481-
}];
482-
483-
let extraClassDeclaration = [{
484-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
485-
}];
486-
}
487-
488-
//----------------------------------------------------------------------------//
489-
// AVX: Convert BF16 to F32 and broadcast into packed F32
490-
//----------------------------------------------------------------------------//
491-
492-
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
493-
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
494-
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
495-
let description = [{
496-
#### From the Intel Intrinsics Guide:
497-
498-
Convert scalar BF16 (16-bit) floating-point element stored at memory locations
499-
starting at location `__A` to a single-precision (32-bit) floating-point,
500-
broadcast it to packed single-precision (32-bit) floating-point elements,
501-
and store the results in `dst`.
502-
503-
Example:
504-
```mlir
505-
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
506-
```
507-
}];
508-
let arguments = (ins AnyMemRef:$a);
509-
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
510-
let assemblyFormat =
511-
"$a attr-dict`:` type($a)`->` type($dst)";
512-
513-
let extraClassDefinition = [{
514-
std::string $cppClass::getIntrinsicName() {
515-
std::string intr = "llvm.x86.vbcstnebf162ps";
516-
VectorType vecType = getDst().getType();
517-
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
518-
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
519-
intr += std::to_string(opBitWidth);
520-
return intr;
521-
}
522-
}];
523-
524-
let extraClassDeclaration = [{
525-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
526-
}];
527-
528-
}
529-
530411
#endif // X86VECTOR_OPS

mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17-
#include "mlir/Conversion/LLVMCommon/Pattern.h"
18-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1917
#include "mlir/IR/BuiltinTypes.h"
2018
#include "mlir/IR/Dialect.h"
2119
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
5858
}],
5959
/*retType=*/"SmallVector<Value>",
6060
/*methodName=*/"getIntrinsicOperands",
61-
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
61+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
6262
/*methodBody=*/"",
6363
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
6464
>,

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,6 @@ void x86vector::X86VectorDialect::initialize() {
3131
>();
3232
}
3333

34-
static SmallVector<Value>
35-
getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
36-
RewriterBase &rewriter,
37-
const LLVMTypeConverter &typeConverter) {
38-
SmallVector<Value> operands;
39-
auto opType = memrefVal.getType();
40-
41-
Type llvmStructType = typeConverter.convertType(opType);
42-
Value llvmStruct =
43-
rewriter
44-
.create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
45-
.getResult(0);
46-
MemRefDescriptor memRefDescriptor(llvmStruct);
47-
48-
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
49-
operands.push_back(ptr);
50-
51-
return operands;
52-
}
53-
5434
LogicalResult x86vector::MaskCompressOp::verify() {
5535
if (getSrc() && getConstantSrc())
5636
return emitError("cannot use both src and constant_src");
@@ -65,8 +45,8 @@ LogicalResult x86vector::MaskCompressOp::verify() {
6545
return success();
6646
}
6747

68-
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
69-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
48+
SmallVector<Value>
49+
x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
7050
auto loc = getLoc();
7151

7252
auto opType = getA().getType();
@@ -84,8 +64,7 @@ SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
8464
}
8565

8666
SmallVector<Value>
87-
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
88-
const LLVMTypeConverter &typeConverter) {
67+
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
8968
SmallVector<Value> operands(getOperands());
9069
// Dot product of all elements, broadcasted to all elements.
9170
Value scale =
@@ -95,22 +74,5 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
9574
return operands;
9675
}
9776

98-
SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
99-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
101-
}
102-
103-
SmallVector<Value>
104-
x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
105-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
106-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
107-
}
108-
109-
SmallVector<Value>
110-
x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
111-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
112-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
113-
}
114-
11577
#define GET_OP_CLASSES
11678
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion
9696
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
9797
PatternRewriter &rewriter) const override {
9898
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
99-
op.getIntrinsicOperands(rewriter, typeConverter),
100-
typeConverter, rewriter);
99+
op.getIntrinsicOperands(rewriter), typeConverter,
100+
rewriter);
101101
}
102102

103103
private:
@@ -114,8 +114,7 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
114114

115115
void mlir::configureX86VectorLegalizeForExportTarget(
116116
LLVMConversionTarget &target) {
117-
target.addIllegalOp<
118-
MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
119-
CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
120-
CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
117+
target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
118+
Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, RsqrtOp,
119+
DotOp>();
121120
}

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -95,60 +95,6 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
9595
return %0 : vector<16xbf16>
9696
}
9797

98-
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
99-
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
100-
%a: memref<8xbf16>) -> vector<4xf32>
101-
{
102-
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
103-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
104-
return %0 : vector<4xf32>
105-
}
106-
107-
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
108-
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
109-
%a: memref<16xbf16>) -> vector<8xf32>
110-
{
111-
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
112-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
113-
return %0 : vector<8xf32>
114-
}
115-
116-
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
117-
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
118-
%a: memref<8xbf16>) -> vector<4xf32>
119-
{
120-
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
121-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
122-
return %0 : vector<4xf32>
123-
}
124-
125-
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
126-
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
127-
%a: memref<16xbf16>) -> vector<8xf32>
128-
{
129-
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
130-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
131-
return %0 : vector<8xf32>
132-
}
133-
134-
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
135-
func.func @avxbf16_bsct_bf16_to_f32_packed_128(
136-
%a: memref<1xbf16>) -> vector<4xf32>
137-
{
138-
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
139-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
140-
return %0 : vector<4xf32>
141-
}
142-
143-
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
144-
func.func @avxbf16_bsct_bf16_to_f32_packed_256(
145-
%a: memref<1xbf16>) -> vector<8xf32>
146-
{
147-
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
148-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
149-
return %0 : vector<8xf32>
150-
}
151-
15298
// CHECK-LABEL: func @avx_rsqrt
15399
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
154100
{

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -94,66 +94,6 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
9494
return %0 : vector<16xbf16>
9595
}
9696

97-
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
98-
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
99-
%a: memref<8xbf16>) -> vector<4xf32>
100-
{
101-
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
102-
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
103-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
104-
return %0 : vector<4xf32>
105-
}
106-
107-
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
108-
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
109-
%a: memref<16xbf16>) -> vector<8xf32>
110-
{
111-
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
112-
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
113-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
114-
return %0 : vector<8xf32>
115-
}
116-
117-
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
118-
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
119-
%a: memref<8xbf16>) -> vector<4xf32>
120-
{
121-
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
122-
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
123-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
124-
return %0 : vector<4xf32>
125-
}
126-
127-
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
128-
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
129-
%a: memref<16xbf16>) -> vector<8xf32>
130-
{
131-
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
132-
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
133-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
134-
return %0 : vector<8xf32>
135-
}
136-
137-
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
138-
func.func @avxbf16_bcst_bf16_to_f32_128(
139-
%a: memref<1xbf16>) -> vector<4xf32>
140-
{
141-
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
142-
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
143-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
144-
return %0 : vector<4xf32>
145-
}
146-
147-
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
148-
func.func @avxbf16_bcst_bf16_to_f32_256(
149-
%a: memref<1xbf16>) -> vector<8xf32>
150-
{
151-
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
152-
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
153-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
154-
return %0 : vector<8xf32>
155-
}
156-
15797
// CHECK-LABEL: func @avx_rsqrt
15898
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
15999
{

0 commit comments

Comments
 (0)