-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() #93361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics. For a scalable vector type, this should be caculated as VectorScaleOp multiplied by base vector length, e.g.: for <[4]xf32> we should return: vscale * 4.
@llvm/pr-subscribers-mlir Author: Zhaoshi Zheng (zhaoshiz) ChangesLLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics. For a scalable vector type, this should be caculated as VectorScaleOp multiplied by base vector length, e.g.: for <[4]xf32> we should return: vscale * 4. Full diff: https://github.com/llvm/llvm-project/pull/93361.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..18bd9660525b4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
llvmType);
}
-/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
- return rewriter.create<LLVM::ConstantOp>(
+ Value vLen = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+
+ if (!vType.getScalableDims()[0])
+ return vLen;
+
+ // Create VScale*vShape[0] and return it as vector length.
+ Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+ vScale = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), vScale);
+ vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
+ return vLen;
}
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index f98a05f8d17e2..209afa217437b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+// -----
+
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[4]xf32>, vector<[4]xi1>, i32) -> f32
+
+
// -----
func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
@@ -167,6 +186,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+// -----
+
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[16]xi8>, %mask : vector<[16]xi1>) -> i8 {
+ %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xi8> into i8 } : vector<[16]xi1> -> i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[16]xi8>, vector<[16]xi1>, i32) -> i8
+
+
// -----
func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
gentle ping... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks. How did you decide what tests to "duplicate"? There seems to be more cases with vector.reduction
.
if (!vType.getScalableDims()[0]) | ||
return vLen; | ||
|
||
// Create VScale*vShape[0] and return it as vector length. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] We tend to write "vscale" rather than VScale. Also, why Shape
rather than shape[0]
or (even better, referring to a C++ variable): vShape[0]
. In fact, you could rename vLen
as baseVecLength
to make the variable names more descriptive and use that in the comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to vScale * baseVecLength
, refering to actual variable names in code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel duplicating all tests is a bit redundant.. the triton test case is using 'add', other vector.mask %mask {vector.reduction ...} tests don't offer additional coverage on code path in mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp, i.e., the part maps vector.reduction to llvm.intr.vp.reduce. is not changed. I've dup-ed some tests
gentle ping.. |
gentle ping again.. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really sorry about the delay, Ive been a bit behind with reviews lately :(
One small comment, otherwise LG
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) - | |||
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32 | |||
|
|||
|
|||
// ----- | |||
|
|||
func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use identical shapes to what's used in @masked_reduce_add_f32
. This way the only thing that changes is "scalability" rather than two things at a time. Same comment for other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I just updated the tests.
… counterparts of regular vectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lovely, thank you for working on this and apologies for the delay, LGTM!
…llvm#93361) LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics. For a scalable vector type, this should be caculated as VectorScaleOp multiplied by base vector length, e.g.: for <[4]xf32> we should return: vscale * 4.
LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.
For a scalable vector type, this should be caculated as VectorScaleOp multiplied by base vector length, e.g.: for <[4]xf32> we should return: vscale * 4.