Skip to content

[mlir][Vector] Tighten up application conditions in TransferReadAfter… #143869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

nicolasvasilache
Copy link
Contributor

@nicolasvasilache nicolasvasilache commented Jun 12, 2025

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.

In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore.

The last cast {yes-broadcast, yes-transpose} continues to apply but should arguably be removed from the future because creating transposes as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

  • either into the transfer, or
  • outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not) but this is beyond the scope of this PR.

@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.


Full diff: https://github.com/llvm/llvm-project/pull/143869.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+22-7)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+37)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..32e9fcf6ed044 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,27 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // The masked case is too complext atm, bail.
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>

@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.


Full diff: https://github.com/llvm/llvm-project/pull/143869.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+22-7)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+37)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..32e9fcf6ed044 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
 
   LogicalResult matchAndRewrite(TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    if (readOp.hasOutOfBoundsDim() ||
-        !llvm::isa<RankedTensorType>(readOp.getShapedType()))
-      return failure();
     auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
+    // Bail if we need an alias analysis.
+    if (!readOp.hasPureTensorSemantics() || !readOp.hasPureTensorSemantics())
+      return failure();
+    // Bail if we need a bounds analysis.
+    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+      return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
     // chunk we could do an extract_strided_slice.
     if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,27 @@ struct TransferReadAfterWriteToBroadcast
     if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
         getUnusedDimsBitVector({defWrite.getPermutationMap()}))
       return failure();
-    if (readOp.getIndices() != defWrite.getIndices() ||
-        readOp.getMask() != defWrite.getMask())
+    // This pattern should only catch the broadcast case, the non-broadcast case
+    // should be done separately to keep application conditions clean and
+    // separate.
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    bool bcast = !readMap.getBroadcastDims().empty() ||
+                 !writeMap.getBroadcastDims().empty();
+    if (!bcast)
+      return failure();
+    // At this point, we know we have a bcast.
+    // The masked case is too complext atm, bail.
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
       return failure();
+
     Value vec = defWrite.getVector();
     // TODO: loop through the chain of transfer_write if we can prove that they
     // don't overlap with the transfer_read. This requires improving
     // `isDisjointTransferIndices` helper.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
     AffineMap map = readMap.compose(writeMap);
     if (map.getNumResults() == 0)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<?x?xf32>, vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+//   CHECK-NOT:   vector.broadcast
+//   CHECK-NOT:   vector.transpose
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+    vector<4x2xf32>, tensor<?x?xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+  permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<?x?xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
 //       CHECK:   %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

@nicolasvasilache nicolasvasilache force-pushed the users/nico/transfer-read-after-write-broadcast-tightening branch from af7e8a6 to 0210067 Compare June 12, 2025 12:58
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, makes sense!

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM Thanks!

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.

In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore.

The last cast {yes-broadcast, yes-transpose} continues to apply but should arguably be removed from the future because creating transposes as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not) but this is beyond the scope of this PR.
@nicolasvasilache nicolasvasilache force-pushed the users/nico/transfer-read-after-write-broadcast-tightening branch from 0210067 to 0dd282f Compare June 12, 2025 14:47
@nicolasvasilache nicolasvasilache merged commit e4de74b into main Jun 12, 2025
7 checks passed
@nicolasvasilache nicolasvasilache deleted the users/nico/transfer-read-after-write-broadcast-tightening branch June 12, 2025 15:11
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jun 12, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-flang-rhel-clang running on ppc64le-flang-rhel-test while building mlir at step 6 "test-build-unified-tree-check-flang".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/30622

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-flang) failure: test (failure)
******************** TEST 'Flang :: Semantics/modfile75.F90' FAILED ********************
Exit Code: 2

Command Output (stderr):
--
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 && /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 | /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90 # RUN: at line 1
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=1 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -c -fhermetic-module-files -DWHICH=2 /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/flang -fc1 -fdebug-unparse /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
+ /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
error: Semantic errors in /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90
/home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90:15:11: error: Must be a constant value
    integer(c_int) n
            ^^^^^
FileCheck error: '<stdin>' is empty.
FileCheck command line:  /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/build/bin/FileCheck /home/buildbots/llvm-external-buildbots/workers/ppc64le-flang-rhel-test/ppc64le-flang-rhel-clang-build/llvm-project/flang/test/Semantics/modfile75.F90

--

********************


tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
llvm#143869)

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate
incorrect IR.

In the process, we disable the application of this pattern in the case
where there is no broadcast; this should be handled separately and may
more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this
pattern and arguably could also generate incorrect IR (and was also
untested): this case does not apply anymore.

The last cast {yes-broadcast, yes-transpose} continues to apply but
should arguably be removed from the future because creating transposes
as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (i.e.
does your DMA HW allow transpose on the fly or not) but this is beyond
the scope of this PR.

Co-authored-by: Nicolas Vasilache <[email protected]>
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jun 19, 2025
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jun 20, 2025
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jun 20, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 23, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 23, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 23, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 23, 2025
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Jun 24, 2025
akuhlens pushed a commit to akuhlens/llvm-project that referenced this pull request Jun 24, 2025
llvm#143869)

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate
incorrect IR.

In the process, we disable the application of this pattern in the case
where there is no broadcast; this should be handled separately and may
more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this
pattern and arguably could also generate incorrect IR (and was also
untested): this case does not apply anymore.

The last cast {yes-broadcast, yes-transpose} continues to apply but
should arguably be removed from the future because creating transposes
as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:

- either into the transfer, or
- outside of the transfer

Ideally, this would be target-dependent and not a canonicalization (i.e.
does your DMA HW allow transpose on the fly or not) but this is beyond
the scope of this PR.

Co-authored-by: Nicolas Vasilache <[email protected]>
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Jun 25, 2025
fabianmcg added a commit to iree-org/llvm-project that referenced this pull request Jun 25, 2025
fabianmcg added a commit to iree-org/llvm-project that referenced this pull request Jun 25, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants