Skip to content

[mlir][Vector] Add vector bitwidth target to xfer op flattening #81966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 21, 2024

Conversation

dcaballe
Copy link
Contributor

This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that the flattening doesn't happen if the trailing dimension of the read/writen vector is larger than this bitwidth (i.e., we are already able to fill at least one vector register with that size).

This PR adds an optional bitwidth parameter to the vector xfer op
flattening transformation so that the flattening doesn't happen if the
trailing dimension of the read/writen vector is larger than this
bitwidth (i.e., we are already able to fill at least one vector
register with that size).
@llvmbot
Copy link
Member

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that the flattening doesn't happen if the trailing dimension of the read/writen vector is larger than this bitwidth (i.e., we are already able to fill at least one vector register with that size).


Full diff: https://github.com/llvm/llvm-project/pull/81966.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+7-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+40-5)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+34-2)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..cb3b3de8051d6f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -328,8 +328,13 @@ void populateDropUnitDimWithShapeCastPatterns(RewritePatternSet &patterns,
 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
 /// to transform multiple small n-D transfers into a larger 1-D transfer where
 /// the memref contiguity properties allow it.
-void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns,
-                                           PatternBenefit benefit = 1);
+///
+/// Flattening is only applied if the bitwidth of the trailing vector dimension
+/// is smaller or equal to `targetVectorBitwidth`.
+void populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns,
+    unsigned targetVectorBitwidth = std::numeric_limits<unsigned>::max(),
+    PatternBenefit benefit = 1);
 
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..04e5a816dd91e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
@@ -535,9 +534,17 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferReadPattern
     : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
+                                               unsigned vectorBitwidth,
+                                               PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                 PatternRewriter &rewriter) const override {
@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Rewrites contiguous row-major vector.transfer_write ops by inserting
@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// already reduced i.e. without unit dims.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+public:
+  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
+                                                unsigned vectorBitwidth,
+                                                PatternBenefit benefit)
+      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
+        targetVectorBitwidth(vectorBitwidth) {}
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                 PatternRewriter &rewriter) const override {
@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
+    if (!vectorType.getElementType().isSignlessIntOrFloat())
+      return failure();
+    unsigned trailingVectorDimBitwidth =
+        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
+    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
+      return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
+
+private:
+  // Minimum bitwidth that the trailing vector dimension should have after
+  // flattening.
+  unsigned targetVectorBitwidth;
 };
 
 /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
+    PatternBenefit benefit) {
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
-      patterns.getContext(), benefit);
+      patterns.getContext(), targetVectorBitwidth, benefit);
   populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..5ba3ac824770ce 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,7 +66,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,35 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+      %arg : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>) -> vector<5x4x3x20xi32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i32
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>, vector<5x4x3x20xi32>
+    return %v : vector<5x4x3x20xi32>
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_read(
+//   CHECK-NOT:    tensor.collapse_shape
+
+// -----
+
+func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+      %arg0 : memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>,
+      %arg1 : vector<5x4x3x20xi32>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] :
+      vector<5x4x3x20xi32>, memref<5x4x3x20xi32, strided<[24, 6, 20, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:  func.func @trailing_dim_larger_than_target_vector_bitwidth_write(
+//   CHECK-NOT:    tensor.collapse_shape
+
+
+
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..57d104e80d7243 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
   }
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateFlattenVectorTransferPatterns(patterns);
+    constexpr unsigned targetVectorBitwidth = 512;
+    populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };

@@ -480,7 +480,8 @@ struct TestFlattenVectorTransferPatterns
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFlattenVectorTransferPatterns(patterns);
constexpr unsigned targetVectorBitwidth = 512;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make this a pass option and the toggle it in the test to demonstrate how it impacts the pattern?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, makes sense! LGTM

Comment on lines 1 to +2
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s --check-prefixex=CHECK,CHECK-DEFAULT
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK,CHECK-128B

This way, you don't need to add CHECK-128B to every "sub-file" in this file and only update checks for e.g. transfer_read_dims_mismatch_non_zero_indices:

// CHECK-DEFAULT-LABEL: func @transfer_read_dims_match_contiguous
//       CHECK-DEFAULT:   memref.collapse_shape
// ...
// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
//       CHECK-128B:   memref.collapse_shape
// ...

https://llvm.org/docs/CommandGuide/FileCheck.html#options

[nit] Side note. IMHO, Once "prefixes" are involved, "CHECK" in prefix name becomes noise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow. I added the check for wider coverage. We can leave tests without check rules for a particular run. The only difference I see is that we could have a single label check that is checked in both runs but that doesn't seem to be what you meant

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was referring to the CHECK lines that are duplicated, eg for “ @transfer_read_0d”. But now I see that the duplication is only happening for negative tests, right? Sorry, was reading this in a rush and missed the finer details :(

Comment on lines +79 to 82
%2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
vector<1x2x6xi32>, memref<1x2x6xi32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something for the future ...

This example, on a machine with vectors which are 128 bits wide (e.g. Arm) would actually benefit from flattening. With 6 elements, we'd use 1.5 vector registers. And and with 12 elements, we'd use 3. That would be better utilization.

Would that make sense as a TODO? (not asking for it in this patch)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the ultimate goal: partial and more targeted flattening but one step at a time. We first have to flatten the producer/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.

unsigned trailingVectorDimBitwidth =
vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
if (trailingVectorDimBitwidth >= targetVectorBitwidth)
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be cases that only partial slice is contiguous. In this case, we could flatten trailing dims. I wonder if we will relax this a little more in the near future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the ultimate goal: partial and more targeted flattening but one step at a time. We first have to flatten the producer/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.

Comment on lines +736 to +738
// Minimum bitwidth that the trailing vector dimension should have after
// flattening.
unsigned targetVectorBitwidth;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after flattening? It seems not correct with the implementation. With targetVectorBitWidth=128 and vector<1x2x6xi32> type, it becomes vector<12xi32> after flattening; the bit-width of trailing dim is 384. It sounds like we should flatten it with the configuration.

How about update it to Maximum bitwidth that ... before flattening?

@dcaballe dcaballe merged commit 71441ed into llvm:main Feb 21, 2024
@dcaballe dcaballe deleted the flatten-bitwidth branch February 21, 2024 17:22
@dcaballe dcaballe restored the flatten-bitwidth branch April 24, 2024 13:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants