[mlir][x86vector] AVX512-BF16 Dot op #124800
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Adam Siemieniuk (adam-smnk)

Changes: Adds AVX512-BF16 operation definitions and a bf16 dot-product operation.

Full diff: https://github.com/llvm/llvm-project/pull/124800.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index fa3f0ee0460b1d..409ef9ce16054e 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -271,6 +271,94 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
VectorOfLengthAndType<[8], [I64]>:$b);
}
+//===----------------------------------------------------------------------===//
+// AVX512-BF16 op definitions
+//===----------------------------------------------------------------------===//
+
+// Operation that is part of the input dialect.
+class AVX512BF16_Op<string mnemonic, list<Trait> traits = []> :
+ Op<X86Vector_Dialect, "avx512bf16." # mnemonic, traits> {}
+
+// Intrinsic operation used during lowering to LLVM IR.
+class AVX512BF16_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
+ LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+ "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+ [], [], traits, numResults>;
+
+// Defined by first result overload. May have to be extended for other
+// instructions in the future.
+class AVX512BF16_IntrOverloadedOp<string mnemonic,
+ list<Trait> traits = []> :
+ LLVM_IntrOpBase<X86Vector_Dialect, "avx512bf16.intr." # mnemonic,
+ "x86_avx512bf16_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/[0],
+ /*list<int> overloadedOperands=*/[],
+ traits, /*numResults=*/1>;
+
+//----------------------------------------------------------------------------//
+// AVX512-BF16 Dot
+//----------------------------------------------------------------------------//
+
+def DotBF16Op : AVX512BF16_Op<"dot", [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "dst"]>,
+    TypesMatchWith<"`a` has twice as many elements as `src`",
+ "src", "a",
+ "VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0] * 2}, "
+ "BFloat16Type::get($_self.getContext()))">]> {
+ let summary = "Dot BF16 op";
+ let description = [{
+ The `dot` op is an AVX512-BF16 specific op that can lower to the proper
+ LLVMAVX512BF16 operation `llvm.dpbf16ps` depending on the width of MLIR
+ vectors it is applied to.
+
+ #### From the Intel Intrinsics Guide:
+
+ Compute dot-product of BF16 (16-bit) floating-point pairs in `a` and `b`,
+ accumulating the intermediate single-precision (32-bit) floating-point
+ elements with elements in `src`, and store the results in `dst`.
+
+ Example:
+ ```mlir
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ ```
+ }];
+ let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$a,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>:$b
+ );
+ let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
+ let assemblyFormat =
+ "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
+}
+
+def DotBF16Ps128IntrOp : AVX512BF16_IntrOp<"dpbf16ps.128", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
+ VectorOfLengthAndType<[8], [BF16]>:$a,
+ VectorOfLengthAndType<[8], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
+}
+
+def DotBF16Ps256IntrOp : AVX512BF16_IntrOp<"dpbf16ps.256", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
+ VectorOfLengthAndType<[16], [BF16]>:$a,
+ VectorOfLengthAndType<[16], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
+}
+
+def DotBF16Ps512IntrOp : AVX512BF16_IntrOp<"dpbf16ps.512", 1, [Pure,
+ AllTypesMatch<["a", "b"]>,
+ AllTypesMatch<["src", "res"]>]> {
+ let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+ VectorOfLengthAndType<[32], [BF16]>:$a,
+ VectorOfLengthAndType<[32], [BF16]>:$b);
+ let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
+}
+
//===----------------------------------------------------------------------===//
// AVX op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e918473cae9e3a..260ac9ce589a38 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -90,6 +90,47 @@ struct MaskCompressOpConversion
}
};
+struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
+ using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto typeA = dyn_cast<VectorType>(op.getA().getType());
+ unsigned elemBitWidth = typeA.getElementTypeBitWidth();
+ unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
+
+ auto opType = adaptor.getSrc().getType();
+ auto opSrc = adaptor.getSrc();
+ auto opA = adaptor.getA();
+ auto opB = adaptor.getB();
+
+ switch (opBitWidth) {
+ case 128: {
+ rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
+ opB);
+ break;
+ }
+ case 256: {
+ rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
+ opB);
+ break;
+ }
+ case 512: {
+ rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
+ opB);
+ break;
+ }
+ default: {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported AVX512-BF16 dot variant");
+ }
+ }
+
+ return success();
+ }
+};
+
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
@@ -161,8 +202,8 @@ using Registry = RegistryImpl<
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
Registry::registerPatterns(converter, patterns);
- patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
- converter);
+ patterns.add<MaskCompressOpConversion, DotBF16OpConversion, RsqrtOpConversion,
+ DotOpConversion>(converter);
}
void mlir::configureX86VectorLegalizeForExportTarget(
@@ -170,6 +211,10 @@ void mlir::configureX86VectorLegalizeForExportTarget(
Registry::configureTarget(target);
target.addLegalOp<MaskCompressIntrOp>();
target.addIllegalOp<MaskCompressOp>();
+ target.addLegalOp<DotBF16Ps128IntrOp>();
+ target.addLegalOp<DotBF16Ps256IntrOp>();
+ target.addLegalOp<DotBF16Ps512IntrOp>();
+ target.addIllegalOp<DotBF16Op>();
target.addLegalOp<RsqrtIntrOp>();
target.addIllegalOp<RsqrtOp>();
target.addLegalOp<DotIntrOp>();
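The conversion pattern above picks the intrinsic variant from the total bit width of the bf16 operand (lane count × 16 bits). A minimal Python sketch of that dispatch logic, mirroring the `switch` in `DotBF16OpConversion` (this is an illustrative model only, not MLIR API code):

```python
def select_dot_intrinsic(bf16_lanes):
    # Total vector width in bits: bf16_lanes 16-bit elements.
    op_bit_width = bf16_lanes * 16
    # Each supported width maps to one dpbf16ps intrinsic op.
    variants = {
        128: "x86vector.avx512bf16.intr.dpbf16ps.128",
        256: "x86vector.avx512bf16.intr.dpbf16ps.256",
        512: "x86vector.avx512bf16.intr.dpbf16ps.512",
    }
    if op_bit_width not in variants:
        # Corresponds to rewriter.notifyMatchFailure in the pattern.
        raise ValueError("unsupported AVX512-BF16 dot variant")
    return variants[op_bit_width]

print(select_dot_intrinsic(8))   # x86vector.avx512bf16.intr.dpbf16ps.128
print(select_dot_intrinsic(32))  # x86vector.avx512bf16.intr.dpbf16ps.512
```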
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
index 8b9006395fdfe4..cbc8c3051c6ab1 100644
--- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -43,6 +43,33 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
}
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
+ %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.128
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
+ %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.256
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
+ %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+ // CHECK: x86vector.avx512bf16.intr.dpbf16ps.512
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
index 557978b51c5123..f7111f75db6180 100644
--- a/mlir/test/Dialect/X86Vector/roundtrip.mlir
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -47,6 +47,33 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
}
+// CHECK-LABEL: func @avx512bf16_dot_128
+func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
+ %b: vector<8xbf16>) -> (vector<4xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<8xbf16> -> vector<4xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_256
+func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
+ %b: vector<16xbf16>) -> (vector<8xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<16xbf16> -> vector<8xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avx512bf16_dot_512
+func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
+ %b: vector<32xbf16>) -> (vector<16xf32>)
+{
+ // CHECK: x86vector.avx512bf16.dot {{.*}} : vector<32xbf16> -> vector<16xf32>
+ %0 = x86vector.avx512bf16.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
// CHECK-LABEL: func @avx_rsqrt
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
{
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
new file mode 100644
index 00000000000000..fe333f49fc8e14
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/dot-bf16.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm="enable-x86vector" -convert-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-translate --mlir-to-llvmir | \
+// RUN: %lli --entry-function=entry --mattr="avx512bf16" --dlopen=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+ %i0 = arith.constant 0 : i32
+ %i3 = arith.constant 3 : i32
+
+ %src = arith.constant dense<1.0> : vector<4xf32>
+ %a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xbf16>
+ %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xbf16>
+ %dst = x86vector.avx512bf16.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
+
+ %1 = vector.extractelement %dst[%i0 : i32] : vector<4xf32>
+ %2 = vector.extractelement %dst[%i3 : i32] : vector<4xf32>
+ %d = arith.addf %1, %2 : f32
+
+ // CHECK: ( 30, 82, 150, 234 )
+ // CHECK: 264
+ vector.print %dst : vector<4xf32>
+ vector.print %d : f32
+
+ return %i0 : i32
+}
LGTM. FWIW:
- IREE-wise, we are not directly concerned because we use microkernels for matrix multiplication kernels on CPU, e.g. for avx512bf16: https://github.com/iree-org/iree/blob/main/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_bf16.c. Just mentioning why I don't have strong opinions here as a downstream.
- AMD-wise, avx512bf16 is supported on all microarchitectures since Zen4 (2023-ish) but not on earlier CPUs (up to Zen3 inclusive), so there are probably many folks running these tests without avx512bf16 support. I agree with the sentiment that test suites should not crash on older CPUs, so hopefully you can resolve that.
Indeed, our position too. We're now looking into generating simple kernels directly, which is why we want to resurrect this dialect. The plan is to have micro-kernel quality "special lowering" for particular patterns but still use micro-kernels for the more complex stuff, and slowly cover the space. Once we have some minimal prototype, we'll have lessons learned on both dialect- and vector-level transforms, which we'll start working with you guys to upstream. We're hoping this will directly benefit IREE too.
I'd dare to say nobody runs these tests, as their lowering seems incomplete atm ;)
Awesome, thank you!
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/13818

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/117/builds/6151

The test should've been restricted to the x86 target. I'll make a fix.

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/55/builds/6297

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/24/builds/4745

LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/85/builds/4911
I haven't thought too deeply about this but I believe the current structure of the
I wonder if we could find a way to conditionally register operations within the same dialect. This approach would help us implement the second point without having to create an independent dialect for each sub-ISA. However, if this conditional registration is not feasible, creating separate dialects for each sub-ISA doesn't sound terrible to me, esp. if we plan to add a significant number of operations. It would be great to take some action in this regard before things get out of control...
Agree. The idea is not to add "all or a lot of operations", but to add enough (3-4) to have a prototype that guides us through a scalable design.
Conditionally loading operations in a dialect is not trivial to do, nor does it make sense for cross-compilation, so I'd avoid going there as a first step. I think a simpler step would be to conditionally load entire dialects in response to target (not host) information, but that still needs us to have a coherent target description story. The DLTI work is leading us there, but it's still not required (like LLVM's triple/DL is), so we can't rely on that yet. My aim is to make DLTI mandatory like in LLVM, and then we can start making smarter compile-time and run-time selections.
It does sound terrible to me. SSE, AVX, AVX2 and AVX512 have a dozen variations each; NEON and SVE have a few more. And none of those are really there for code generation (like LLVM's intrinsics), but as short-cuts to a handful of LLVM intrinsics. As soon as we have specialized lowering transforms that can generate LLVM intrinsics directly, most of those instructions become irrelevant. We do not want to recreate the LLVM pipeline in MLIR.
There is design intent behind this PR; we don't want to just dump instructions here.
To fill this gap, different downstream projects came up with their own ways to represent target information and pass it to upstream code. For example, we have some AVX2-specific transformations with an API that allows downstream projects to enable them using target information. I believe designing with that in mind would help the transition to whatever we end up adopting without major rework.
Well, AVX512 alone has thousands of intrinsics, even if many in that list can be folded together. I think loading all AVX512 ops vs. just the supported subsets could really make a difference. We already have dialects with just a handful of operations, so I don't see this being too different. It's also difficult to know how these operations will be used in the long run. We already have some passes running on SME operations, and I wouldn't be surprised if we ended up introducing some transformations/canonicalizations of the AVX2 operations that we generate. My comment is mostly giving visibility to some concerns that were brought to my attention about the current state of some of these dialects. It's great to hear that there is a plan to improve this. I'll be happy to help to the extent possible :)
That's the idea.
We're not proposing to complete the extension by any stretch of the imagination. This would be wrong on too many levels. :)
Transformations, yes; canonicalizations, no. We want to have a "special lowering" for particular patterns. I do not believe we should have such a low-level dialect in the first place (as @Groverkss said, this would be akin to Rocm), which does not make sense, as you also pointed out. We know that, and agree with the sentiment. So, we expose the problems in the current design and come up with a plan to refactor it. My preference is to not need the
Awesome! The intention was to grab the attention of people who care, so we can have a design upstream that works across usages, not just our own. We have two or three more ops to go, and then we'll have all we need for a transform that converts contractions into efficient micro-kernels at the compiler level. With that example upstream, we can begin to discuss the pros and cons and redesign both. We'll be counting on you to help us!
Adds AVX512 bf16 dot-product operation and defines lowering to LLVM intrinsics. The AVX512 intrinsic operation definition is extended with an optional extension field that allows specifying the necessary LLVM mnemonic suffix, e.g., "bf16" for x86_avx512bf16_ intrinsics.