-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Fix cases with multiple yielded transfer_read ops #71625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
qedawkins
merged 2 commits into
llvm:main
from
qedawkins:transfer_read_distribution_fixes
Nov 9, 2023
Merged
[mlir][vector] Fix cases with multiple yielded transfer_read ops #71625
qedawkins
merged 2 commits into
llvm:main
from
qedawkins:transfer_read_distribution_fixes
Nov 9, 2023
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir Author: Quinn Dawkins (qedawkins) Changes: This fixes two bugs:
Full diff: https://github.com/llvm/llvm-project/pull/71625.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..f67e03510ba6ca6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -801,13 +801,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
- if (!operand)
- return failure();
- auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
- // Don't duplicate transfer_read ops when distributing.
- if (!read.getResult().hasOneUse())
+ // Try to find a distributable yielded read. Note that this pattern can
+ // still fail at the end after distribution, in which case this might have
+ // missed another distributable read.
+ vector::TransferReadOp read;
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ OpOperand *operand;
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (!definedOp)
+ continue;
+ auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
+ if (!maybeRead)
+ continue;
+ if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ continue;
+ // Don't duplicate transfer_read ops when distributing.
+ if (!maybeRead.getResult().hasOneUse())
+ continue;
+ read = maybeRead;
+ operand = &yieldOperand;
+ break;
+ }
+ if (!read)
return failure();
unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
@@ -913,6 +931,14 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Move the body of the old warpOp to a new warpOp.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newYieldValues, newResultTypes);
+
+ // Simplify the new warp op after dropping dead results.
+ auto simplifyFn = [&](Operation *op) {
+ if (isOpTriviallyDead(op))
+ rewriter.eraseOp(op);
+ };
+ newWarpOp.getBody()->walk(simplifyFn);
+
// Replace results of the old warpOp by the new, deduplicated results.
SmallVector<Value> newValues;
newValues.reserve(warpOp->getNumResults());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..3f95a39100b2b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
// -----
+func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
+ %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+ "some_use"(%0) : (vector<1xf32>) -> ()
+ %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+ vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
+ }
+ return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
+// CHECK-PROP: vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+// CHECK-PROP: %[[INNER_READ:.+]] = vector.transfer_read
+// CHECK-PROP: "some_use"(%[[INNER_READ]])
+// CHECK-PROP: vector.yield %[[INNER_READ]] : vector<1xf32>
+// CHECK-PROP: vector.transfer_read
+
+// -----
+
+func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+ %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+ %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+ %max = arith.maximumf %0, %1 : vector<64xf32>
+ vector.yield %max : vector<64xf32>
+ }
+ return %r : vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
+// CHECK-PROP-COUNT-2: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-PROP: arith.maximumf {{.*}} : vector<1xf32>
+
+// -----
+
func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
%c0 = arith.constant 0 : index
vector.warp_execute_on_lane_0(%laneid)[32] -> () {
|
@llvm/pr-subscribers-mlir-vector Author: Quinn Dawkins (qedawkins) Changes: This fixes two bugs:
Full diff: https://github.com/llvm/llvm-project/pull/71625.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..f67e03510ba6ca6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -801,13 +801,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- OpOperand *operand = getWarpResult(
- warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
- if (!operand)
- return failure();
- auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
- // Don't duplicate transfer_read ops when distributing.
- if (!read.getResult().hasOneUse())
+ // Try to find a distributable yielded read. Note that this pattern can
+ // still fail at the end after distribution, in which case this might have
+ // missed another distributable read.
+ vector::TransferReadOp read;
+ auto yield = cast<vector::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ OpOperand *operand;
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ Value yieldValues = yieldOperand.get();
+ Operation *definedOp = yieldValues.getDefiningOp();
+ if (!definedOp)
+ continue;
+ auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
+ if (!maybeRead)
+ continue;
+ if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+ continue;
+ // Don't duplicate transfer_read ops when distributing.
+ if (!maybeRead.getResult().hasOneUse())
+ continue;
+ read = maybeRead;
+ operand = &yieldOperand;
+ break;
+ }
+ if (!read)
return failure();
unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
@@ -913,6 +931,14 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Move the body of the old warpOp to a new warpOp.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newYieldValues, newResultTypes);
+
+ // Simplify the new warp op after dropping dead results.
+ auto simplifyFn = [&](Operation *op) {
+ if (isOpTriviallyDead(op))
+ rewriter.eraseOp(op);
+ };
+ newWarpOp.getBody()->walk(simplifyFn);
+
// Replace results of the old warpOp by the new, deduplicated results.
SmallVector<Value> newValues;
newValues.reserve(warpOp->getNumResults());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..3f95a39100b2b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
// -----
+func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
+ %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+ "some_use"(%0) : (vector<1xf32>) -> ()
+ %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+ vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
+ }
+ return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
+// CHECK-PROP: vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+// CHECK-PROP: %[[INNER_READ:.+]] = vector.transfer_read
+// CHECK-PROP: "some_use"(%[[INNER_READ]])
+// CHECK-PROP: vector.yield %[[INNER_READ]] : vector<1xf32>
+// CHECK-PROP: vector.transfer_read
+
+// -----
+
+func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
+ %f0 = arith.constant 0.000000e+00 : f32
+ %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+ %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+ %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+ %max = arith.maximumf %0, %1 : vector<64xf32>
+ vector.yield %max : vector<64xf32>
+ }
+ return %r : vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
+// CHECK-PROP-COUNT-2: vector.transfer_read {{.*}} vector<1xf32>
+// CHECK-PROP: arith.maximumf {{.*}} : vector<1xf32>
+
+// -----
+
func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
%c0 = arith.constant 0 : index
vector.warp_execute_on_lane_0(%laneid)[32] -> () {
|
antiagainst
requested changes
Nov 9, 2023
This fixes two bugs: 1) When deciding whether a transfer read could be propagated out of a warp op, it looked for the first yield operand that was produced by a transfer read. If this transfer read wasn't ready to be distributed, the pattern would not re-check for any other transfer reads that could have been propagated. 2) When dropping dead warp results, we do so by updating the warp op signature and splicing in the old region. This does not add the ops in the body of the warp op back to the pattern applicator's worklist, and thus those operations won't be DCE'd. This is a problem for patterns like the one for transfer reads that will still see the dead operation as a user.
7f9b0f3
to
bfa040d
Compare
antiagainst
approved these changes
Nov 9, 2023
qedawkins
added a commit
to iree-org/llvm-project
that referenced
this pull request
Nov 10, 2023
(llvm#71625) This fixes two bugs: 1) When deciding whether a transfer read could be propagated out of a warp op, it looked for the first yield operand that was produced by a transfer read. If this transfer read wasn't ready to be distributed, the pattern would not re-check for any other transfer reads that could have been propagated. 2) When dropping dead warp results, we do so by updating the warp op signature and splicing in the old region. This does not add the ops in the body of the warp op back to the pattern applicator's worklist, and thus those operations won't be DCE'd. This is a problem for patterns like the one for transfer reads that will still see the dead operation as a user.
zahiraam
pushed a commit
to zahiraam/llvm-project
that referenced
this pull request
Nov 20, 2023
(llvm#71625) This fixes two bugs: 1) When deciding whether a transfer read could be propagated out of a warp op, it looked for the first yield operand that was produced by a transfer read. If this transfer read wasn't ready to be distributed, the pattern would not re-check for any other transfer reads that could have been propagated. 2) When dropping dead warp results, we do so by updating the warp op signature and splicing in the old region. This does not add the ops in the body of the warp op back to the pattern applicator's worklist, and thus those operations won't be DCE'd. This is a problem for patterns like the one for transfer reads that will still see the dead operation as a user.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This fixes two bugs:
1) When deciding whether a transfer read could be propagated out of
   a warp op, it looked for the first yield operand that was produced by
   a transfer read. If this transfer read wasn't ready to be
   distributed, the pattern would not re-check for any other transfer
   reads that could have been propagated.
2) When dropping dead warp results, we do so by updating the warp op
   signature and splicing in the old region. This does not add the ops
   in the body of the warp op back to the pattern applicator's worklist,
   and thus those operations won't be DCE'd. This is a problem for
   patterns like the one for transfer reads that will still see the dead
   operation as a user.