[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() #93361


Merged
merged 4 commits on Jun 13, 2024

Conversation

zhaoshiz
Contributor

LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be calculated as VectorScaleOp multiplied by the base vector length, e.g., for <[4]xf32> we should return vscale * 4.
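As an illustration (value names here are mine, based on the code and test changes below, not taken from the patch itself), the effective vector length for the scalable case is built from a base-length constant, vector.vscale, and a multiply:

  %base   = llvm.mlir.constant(4 : i32) : i32        // base vector length of vector<[4]xf32>
  %vscale = vector.vscale                            // runtime multiple of the base length
  %vs_i32 = arith.index_cast %vscale : index to i32  // vector.vscale yields an index
  %vl     = arith.muli %base, %vs_i32 : i32          // effective vector length: vscale * 4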

@llvmbot
Member

llvmbot commented May 25, 2024

@llvm/pr-subscribers-mlir

Author: Zhaoshi Zheng (zhaoshiz)

Changes

LLVM's Vector Predication Intrinsics require an explicit vector length parameter: https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be calculated as VectorScaleOp multiplied by the base vector length, e.g., for <[4]xf32> we should return vscale * 4.


Full diff: https://github.com/llvm/llvm-project/pull/93361.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+12-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir (+38)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..18bd9660525b4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
                                      llvmType);
 }
 
-/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// Creates a value with the 1-D vector shape provided in `llvmType`.
 /// This is used as effective vector length by some intrinsics supporting
 /// dynamic vector lengths at runtime.
 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
@@ -532,9 +532,19 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
   auto vShape = vType.getShape();
   assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
 
-  return rewriter.create<LLVM::ConstantOp>(
+  Value vLen = rewriter.create<LLVM::ConstantOp>(
       loc, rewriter.getI32Type(),
       rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+
+  if (!vType.getScalableDims()[0])
+    return vLen;
+
+  // Create VScale*vShape[0] and return it as vector length.
+  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
+  vScale = rewriter.create<arith::IndexCastOp>(
+      loc, rewriter.getI32Type(), vScale);
+  vLen = rewriter.create<arith::MulIOp>(loc, vLen, vScale);
+  return vLen;
 }
 
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index f98a05f8d17e2..209afa217437b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
 // CHECK:           "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
 
 
+// -----
+
+func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_f32_scalable(
+// CHECK-SAME:                              %[[INPUT:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                              %[[MASK:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[4]xf32>, vector<[4]xi1>, i32) -> f32
+
+
 // -----
 
 func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
@@ -167,6 +186,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
 // CHECK:           "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
 
 
+// -----
+
+func.func @masked_reduce_add_i8_scalable(%arg0: vector<[16]xi8>, %mask : vector<[16]xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xi8> into i8 } : vector<[16]xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL:   func.func @masked_reduce_add_i8_scalable(
+// CHECK-SAME:                             %[[INPUT:.*]]: vector<[16]xi8>,
+// CHECK-SAME:                             %[[MASK:.*]]: vector<[16]xi1>) -> i8 {
+// CHECK:           %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK:           %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK:           %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK:           %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK:           %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
+// CHECK:           %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
+// CHECK:           "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[16]xi8>, vector<[16]xi1>, i32) -> i8
+
+
 // -----
 
 func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {

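For reference, a minimal way to exercise the new path, mirroring the tests added above (the function name and the mlir-opt pipeline are my assumptions, not part of the patch):

  // Lower with, e.g., mlir-opt -convert-vector-to-llvm (assumed pipeline).
  func.func @scalable_masked_reduce(%v: vector<[4]xf32>, %m: vector<[4]xi1>) -> f32 {
    // The masked reduction lowers to llvm.intr.vp.reduce.fadd, whose vector-length
    // operand is now computed as vscale * 4 instead of a plain constant.
    %0 = vector.mask %m { vector.reduction <add>, %v : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
    return %0 : f32
  }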

github-actions bot commented May 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@zhaoshiz
Contributor Author

gentle ping...

Contributor

@banach-space left a comment


Makes sense, thanks. How did you decide which tests to "duplicate"? There seem to be more cases with vector.reduction.

if (!vType.getScalableDims()[0])
return vLen;

// Create VScale*vShape[0] and return it as vector length.
Contributor


[nit] We tend to write "vscale" rather than VScale. Also, why Shape rather than shape[0] or (even better, referring to a C++ variable) vShape[0]? In fact, you could rename vLen to baseVecLength to make the variable names more descriptive and use that in the comment.

Contributor Author


Changed to vScale * baseVecLength, referring to the actual variable names in the code.

Contributor Author


I feel duplicating all tests would be a bit redundant: the triton test case uses 'add', and the other vector.mask %mask { vector.reduction ... } tests don't offer additional coverage of the code path in mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp, i.e., the part that maps vector.reduction to llvm.intr.vp.reduce. is not changed. I've duplicated some tests.

@zhaoshiz
Contributor Author

zhaoshiz commented Jun 5, 2024

gentle ping..

@zhaoshiz
Contributor Author

gentle ping again..

Contributor

@banach-space left a comment


Really sorry about the delay, I've been a bit behind with reviews lately :(

One small comment, otherwise LG

@@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32


// -----

func.func @masked_reduce_add_f32_scalable(%arg0: vector<[4]xf32>, %mask : vector<[4]xi1>) -> f32 {
Contributor


Please use identical shapes to what's used in @masked_reduce_add_f32. This way the only thing that changes is "scalability" rather than two things at a time. Same comment for other tests.

Contributor Author


Thanks, I just updated the tests.

Contributor

@banach-space left a comment


Lovely, thank you for working on this and apologies for the delay, LGTM!

@zhaoshiz zhaoshiz merged commit abcbbe7 into llvm:main Jun 13, 2024
7 checks passed
EthanLuisMcDonough pushed a commit to EthanLuisMcDonough/llvm-project that referenced this pull request Aug 13, 2024
…llvm#93361)

LLVM's Vector Predication Intrinsics require an explicit vector length
parameter:
https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be calculated as VectorScaleOp
multiplied by the base vector length, e.g., for <[4]xf32> we should return
vscale * 4.