
[mlir][Vector] Fix bug in vector xfer op flattening transformation #81964


Merged: 5 commits merged into llvm:main on Feb 22, 2024

Conversation

@dcaballe (Contributor):

It looks like the affine map generated to compute the indices of the collapsed dimensions used the wrong dim size. For indices `[idx0][idx1]` we computed the collapsed index as `idx0*size0 + idx1` instead of `idx0*size1 + idx1`. This led to correctness issues in convolution tests when enabling this transformation internally.

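For illustration, a minimal standalone sketch of the intended row-major collapse (made-up names, not code from this PR): each index is scaled by the combined size of the dimensions that follow it, so for two indices the result is `idx0*size1 + idx1`.

```cpp
#include <cstdint>
#include <vector>

// Collapse row-major indices into a single linear offset. Note that the size
// of the outermost dimension (size0) never contributes to the result.
int64_t linearizeRowMajor(const std::vector<int64_t> &indices,
                          const std::vector<int64_t> &sizes) {
  int64_t offset = 0;
  for (size_t i = 0; i < indices.size(); ++i) {
    // Scale everything accumulated so far by the current dimension's size,
    // then add this dimension's index; for two dims this is idx0*size1 + idx1.
    offset = offset * sizes[i] + indices[i];
  }
  return offset;
}
```
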
@llvmbot (Member) commented Feb 16, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Full diff: https://github.com/llvm/llvm-project/pull/81964.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+6-2)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+31-3)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b761d1ed888973..5f150be0dd8cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -615,10 +615,14 @@ class FlattenContiguousRowMajorTransferReadPattern
       OpFoldResult offset =
           rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
 
+      auto srcType = dyn_cast<ShapedType>(source.getType());
       for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
-        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        // Multiply each index by the size of the next dimension. The last
+        // dimension (contiguous) is multiplied by one.
+        int64_t nextDimSize =
+            (i == outputRank - 1) ? 1 : srcType.getDimSize(i + 1);
         offset = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, offsetExpr + dim * idxExpr,
+            rewriter, loc, offsetExpr + nextDimSize * idxExpr,
             {offset, transferReadOp.getIndices()[i]});
       }
       if (offset.is<Value>()) {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9976048a3320b6..3025d22eef3623 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -66,14 +66,14 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 6 + s1 * 4)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -99,7 +99,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
                      %m_out: memref<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
   vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x2x6xi32>
@@ -389,3 +389,31 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                              %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  return %8 : vector<2x2xf32>
+}
+
+//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL:    func.func @regression_non_contiguous_dim_read(
+//       CHECK:      %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+//       CHECK:     %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+                                                %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+                                                %idx0 : index, %idx1 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL:  func.func @unsupported_non_contiguous_dim_write(
+//   CHECK-NOT:    memref.collapse_shape

Comment on lines 620 to 625 of VectorTransferOpTransforms.cpp (the new index computation):

        // Multiply each index by the size of the next dimension. The last
        // dimension (contiguous) is multiplied by one.
        int64_t nextDimSize =
            (i == outputRank - 1) ? 1 : srcType.getDimSize(i + 1);
        offset = affine::makeComposedFoldedAffineApply(
            rewriter, loc, offsetExpr + nextDimSize * idxExpr,
@MacDue (Member) commented Feb 16, 2024

Just double checking, is this correct for > 2D?

`[idx0][idx1][idx2]` would be:

`idx0 * size1*size2 + idx1 * size2 + idx2`
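For concreteness, a tiny standalone check of that formula with made-up sizes and indices (my example, not from the thread):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical 3-D row-major shape: size0 x size1 x size2 = 2 x 3 x 4.
  const int64_t size1 = 3, size2 = 4;
  const int64_t idx0 = 1, idx1 = 2, idx2 = 3;
  // Each index is scaled by the product of all trailing dimension sizes.
  const int64_t linear = idx0 * (size1 * size2) + idx1 * size2 + idx2;
  assert(linear == 23); // 1*12 + 2*4 + 3
  return 0;
}
```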

Comment on the updated CHECK line in vector-transfer-flatten.mlir (the `memref<1x43x4x6xi32>` test):

-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 6 + s1 * 4)>
A contributor commented:

Shouldn't this be `s0 * 6 + s1 * (4 * 6)`? As in, this affine map looks incorrect to me.
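For reference, working the strides out by hand for this test's `memref<1x43x4x6xi32>` with indices `[%c0, %idx_1, %idx_2, %c0]` (my arithmetic, not a claim about what the pattern currently emits): the row-major strides of the collapsed `43x4x6` group are 24, 6 and 1, so the collapsed offset would be

```
idx_1 * (4 * 6) + idx_2 * 6 + 0 * 1 = idx_1 * 24 + idx_2 * 6
```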

@nicolasvasilache (Contributor) left a comment:

We have IndexingUtils to avoid reinventing such code every time.

Can we please use that and improve/evolve it as needed?
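For illustration, one possible shape for such a change, sketched against the names already used in the pattern (`srcType`, `firstDimToCollapse`, `outputRank`, `offsetExpr`, `idxExpr`, `offset`). This is only a sketch, not the code that eventually landed; it assumes the collapsed dims have static sizes and that `computeSuffixProduct` from `mlir/Dialect/Utils/IndexingUtils.h` is the kind of helper meant here.

```cpp
#include "mlir/Dialect/Utils/IndexingUtils.h"

// Gather the static sizes of the dimensions being collapsed.
SmallVector<int64_t> collapsedSizes;
for (int64_t i = firstDimToCollapse; i < outputRank; ++i)
  collapsedSizes.push_back(srcType.getDimSize(i));

// Suffix products are the row-major strides, e.g. sizes [43, 4, 6] -> [24, 6, 1].
SmallVector<int64_t> strides = computeSuffixProduct(collapsedSizes);

// Accumulate offset = sum_i(stride_i * index_i) via folded affine.apply ops.
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
  int64_t stride = strides[i - firstDimToCollapse];
  offset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, offsetExpr + stride * idxExpr,
      {offset, transferReadOp.getIndices()[i]});
}
```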

@dcaballe (Contributor, Author):

Thanks! Good catch! Using IndexingUtils now. It should be ready.

@hanhanW (Contributor) left a comment:

LGTM, thanks!

@dcaballe (Contributor, Author):

Thanks! Will land this tomorrow if no more comments.

dcaballe merged commit 847048f into llvm:main on Feb 22, 2024.
dcaballe deleted the flatten-bug branch on February 22, 2024 at 20:37.
dcaballe added a commit to dcaballe/llvm-project that referenced this pull request on Feb 22, 2024:

This test failed after landing llvm#81964 due to a bad merge. I provided a quick fix and this PR adds the rest of the CHECK rules that were not merged properly.
dcaballe added a commit that referenced this pull request on Feb 22, 2024 (#82698):

This test failed after landing #81964 due to a bad merge. I provided a quick fix and this PR adds the rest of the CHECK rules that were not merged properly.