[mlir][Vector] Tighten up application conditions in TransferReadAfter… #143869
Conversation
@llvm/pr-subscribers-mlir-vector

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.

Full diff: https://github.com/llvm/llvm-project/pull/143869.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a295bf1eb4d95..32e9fcf6ed044 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4668,12 +4668,15 @@ struct TransferReadAfterWriteToBroadcast
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
- if (readOp.hasOutOfBoundsDim() ||
- !llvm::isa<RankedTensorType>(readOp.getShapedType()))
- return failure();
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
if (!defWrite)
return failure();
+ // Bail if we need an alias analysis.
+ if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
+ return failure();
+ // Bail if we need a bounds analysis.
+ if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
+ return failure();
// TODO: If the written transfer chunk is a superset of the read transfer
// chunk we could do an extract_strided_slice.
if (readOp.getTransferChunkAccessed() !=
@@ -4684,15 +4687,27 @@ struct TransferReadAfterWriteToBroadcast
if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
getUnusedDimsBitVector({defWrite.getPermutationMap()}))
return failure();
- if (readOp.getIndices() != defWrite.getIndices() ||
- readOp.getMask() != defWrite.getMask())
+ // This pattern should only catch the broadcast case; the non-broadcast
+ // case should be handled separately to keep the application conditions
+ // clean and separate.
+ AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+ AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+ bool bcast = !readMap.getBroadcastDims().empty() ||
+ !writeMap.getBroadcastDims().empty();
+ if (!bcast)
+ return failure();
+ // At this point, we know we have a bcast.
+ // The masked case is too complex at the moment, bail.
+ if (readOp.getMask() || defWrite.getMask())
+ return failure();
+ // If the indices are not the same, a shift may be required; bail.
+ if (readOp.getIndices() != defWrite.getIndices())
return failure();
+
Value vec = defWrite.getVector();
// TODO: loop through the chain of transfer_write if we can prove that they
// don't overlap with the transfer_read. This requires improving
// `isDisjointTransferIndices` helper.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
AffineMap map = readMap.compose(writeMap);
if (map.getNumResults() == 0)
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..3bea659ec96be 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1540,6 +1540,43 @@ func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
// -----
+// CHECK-LABEL: func @store_to_load_tensor_no_actual_broadcast
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @store_to_load_tensor_no_actual_broadcast(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
+ tensor<?x?xf32>, vector<4x2xf32>
+ return %0 : vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize
+// CHECK-NOT: vector.broadcast
+// CHECK-NOT: vector.transpose
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+func.func @store_to_load_tensor_broadcast_out_of_bounds_should_not_canonicalize(%arg0 : tensor<?x?xf32>,
+ %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+ %c0 = arith.constant 0 : index
+ %cf0 = arith.constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] :
+ vector<4x2xf32>, tensor<?x?xf32>
+ %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+ permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+ tensor<?x?xf32>, vector<4x2x6xf32>
+ return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @store_to_load_tensor_broadcast_scalable
// CHECK-SAME: (%[[ARG:.*]]: tensor<?xf32>, %[[V0:.*]]: vector<[4]xf32>)
// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<[4]xf32> to vector<6x[4]xf32>
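For reference, here is a minimal sketch of the rewrite this pattern performs when its (now tightened) conditions do hold. It is modeled on the in-tree @store_to_load_tensor_broadcast_scalable test visible in the diff above; the function name and the concrete shapes are illustrative assumptions, not part of this PR.

// Assumed example: a transfer_read that broadcasts the value written by the
// immediately preceding transfer_write is expected to fold to a
// vector.broadcast of the written vector.
func.func @store_to_load_tensor_broadcast_sketch(%arg0 : tensor<?xf32>,
    %v0 : vector<4xf32>) -> vector<6x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg0[%c0] {in_bounds = [true]} :
    vector<4xf32>, tensor<?xf32>
  %0 = vector.transfer_read %w0[%c0], %cf0
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0) -> (0, d0)>} :
    tensor<?xf32>, vector<6x4xf32>
  // Expected canonical form: vector.broadcast %v0 : vector<4xf32> to vector<6x4xf32>
  return %0 : vector<6x4xf32>
}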
Thanks, makes sense!
LGTM Thanks!
…WriteToBroadcast

The pattern would previously apply in spurious cases and generate incorrect IR.

In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.

The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore.

The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future because creating transposes as part of canonicalization feels dangerous.

There are other patterns that move permutation logic:
- either into the transfer, or
- outside of the transfer.

Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not), but this is beyond the scope of this PR.
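To make the remaining {yes-broadcast, yes-transpose} case concrete, here is a hedged sketch (not an in-tree test; the function name, shapes, and the exact folded form are assumptions): the read both broadcasts a new dimension and permutes the written dimensions, and the pattern is expected to rewrite it into a vector.broadcast followed by a vector.transpose.

// Assumed example: the `0` result in the permutation map broadcasts a new
// dimension of size 6, while d0/d1 are swapped relative to the write.
func.func @store_to_load_tensor_broadcast_and_transpose(%arg0 : tensor<?x?xf32>,
    %v0 : vector<4x2xf32>) -> vector<2x6x4xf32> {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
    vector<4x2xf32>, tensor<?x?xf32>
  %0 = vector.transfer_read %w0[%c0, %c0], %cf0
    {in_bounds = [true, true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} :
    tensor<?x?xf32>, vector<2x6x4xf32>
  // Expected (assumed) folded form: vector.broadcast of %v0, then a
  // vector.transpose yielding vector<2x6x4xf32>.
  return %0 : vector<2x6x4xf32>
}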
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/30622. Here is the relevant piece of the build log for reference:
llvm#143869) …WriteToBroadcast The pattern would previously apply in spurious cases and generate incorrect IR. In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking. The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore. The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future because creating transposes as part of canonicalization feels dangerous. There are other patterns that move permutation logic: either into the transfer, or outside of the transfer. Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not), but this is beyond the scope of this PR. Co-authored-by: Nicolas Vasilache <[email protected]>
…eadAfter… (llvm#143869)" This reverts commit e4de74b.
…ReadAfter… (llvm#143869)" This reverts commit a1e53bc.
…ReadAfter… (llvm#143869)" This reverts commit d5a79ad.
…WriteToBroadcast
The pattern would previously apply in spurious cases and generate incorrect IR.
In the process, we disable the application of this pattern in the case where there is no broadcast; this should be handled separately and may more easily support masking.
The case {no-broadcast, yes-transpose} was previously caught by this pattern and arguably could also generate incorrect IR (and was also untested): this case does not apply anymore.
The last case {yes-broadcast, yes-transpose} continues to apply but should arguably be removed in the future because creating transposes as part of canonicalization feels dangerous.
There are other patterns that move permutation logic:
- either into the transfer, or
- outside of the transfer (see the sketch below for the two forms).
Ideally, this would be target-dependent and not a canonicalization (i.e. does your DMA HW allow transpose on the fly or not) but this is beyond the scope of this PR.
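As a hedged illustration of what moving permutation logic into or outside of the transfer means (the function name and shapes below are made up for this sketch), the two reads are equivalent: the first encodes the transpose in the transfer's permutation map, the second keeps a minor-identity read and materializes an explicit vector.transpose.

// Assumed example contrasting the two equivalent forms.
func.func @permutation_in_vs_out(%t : tensor<?x?xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32
  // Form 1: the permutation lives inside the transfer.
  %a = vector.transfer_read %t[%c0, %c0], %cf0
    {in_bounds = [true, true],
     permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
    tensor<?x?xf32>, vector<2x4xf32>
  // Form 2: minor-identity read, permutation outside as an explicit transpose.
  %b = vector.transfer_read %t[%c0, %c0], %cf0 {in_bounds = [true, true]} :
    tensor<?x?xf32>, vector<4x2xf32>
  %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
  return %a, %bt : vector<2x4xf32>, vector<2x4xf32>
}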