[mlir][vector] Fix cases with multiple yielded transfer_read ops #71625

Merged
merged 2 commits into llvm:main from transfer_read_distribution_fixes on Nov 9, 2023

Conversation

qedawkins (Contributor)

This fixes two bugs:

  1. When deciding whether a transfer_read could be propagated out of
    a warp op, the pattern only inspected the first yield operand that
    was produced by a transfer_read. If that read wasn't ready to be
    distributed, the pattern gave up without re-checking whether any
    other yielded transfer_read could have been propagated (see the
    sketch after this list).
  2. When dropping dead warp results, we do so by updating the warp op
    signature and splicing in the old region. This does not add the ops
    in the body of the warp op back to the pattern applicator's worklist,
    so those operations are never DCE'd. This breaks patterns like the
    one for transfer_read, which still see the dead operation as a user
    (a before/after sketch follows the full diff below).
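
As a concrete illustration of the first bug, here is a minimal before/after sketch adapted from the new test case in the diff below. The "some_use" op is an unregistered placeholder taken from the test; the result name %w is hypothetical, and the indices of the hoisted read are shown unchanged for simplicity (in general the pattern may rewrite them per lane):

// Before: %0 has a second in-body user, so it cannot be distributed. The old
// pattern matched %0 (the first yielded transfer_read), bailed out, and never
// considered %1.
%r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
  %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
  "some_use"(%0) : (vector<1xf32>) -> ()
  %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
  vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
}

// After the fix: the read producing %1 is propagated out of the warp op,
// while %0 stays behind because of its extra use.
%w = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
  %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
  "some_use"(%0) : (vector<1xf32>) -> ()
  vector.yield %0 : vector<1xf32>
}
%1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>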

@llvmbot (Member)

llvmbot commented Nov 8, 2023

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes

This fixes two bugs:

  1. When deciding whether a transfer read could be propagated out of
    a warp op, it looked for the first yield operand that was produced by
    a transfer read. If this transfer read wasn't ready to be
    distributed, the pattern would not re-check for any other transfer
    reads that could have been propagated.
  2. When dropping dead warp results, we do so by updating the warp op
    signature and splicing in the old region. This does not add the ops
    in the body of the warp op back to the pattern applicator's worklist,
    and thus those operations won't be DCE'd. This is a problem for
    patterns like the one for transfer reads that will still see the dead
    operation as a user.

Full diff: https://github.com/llvm/llvm-project/pull/71625.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+33-7)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+37)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e128cc71a5d628c..f67e03510ba6ca6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -801,13 +801,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(
-        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
-    if (!operand)
-      return failure();
-    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
-    // Don't duplicate transfer_read ops when distributing.
-    if (!read.getResult().hasOneUse())
+    // Try to find a distributable yielded read. Note that this pattern can
+    // still fail at the end after distribution, in which case this might have
+    // missed another distributable read.
+    vector::TransferReadOp read;
+    auto yield = cast<vector::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    OpOperand *operand;
+    for (OpOperand &yieldOperand : yield->getOpOperands()) {
+      Value yieldValues = yieldOperand.get();
+      Operation *definedOp = yieldValues.getDefiningOp();
+      if (!definedOp)
+        continue;
+      auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
+      if (!maybeRead)
+        continue;
+      if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        continue;
+      // Don't duplicate transfer_read ops when distributing.
+      if (!maybeRead.getResult().hasOneUse())
+        continue;
+      read = maybeRead;
+      operand = &yieldOperand;
+      break;
+    }
+    if (!read)
       return failure();
     unsigned operandIndex = operand->getOperandNumber();
     Value distributedVal = warpOp.getResult(operandIndex);
@@ -913,6 +931,14 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
     // Move the body of the old warpOp to a new warpOp.
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
         rewriter, warpOp, newYieldValues, newResultTypes);
+
+    // Simplify the new warp op after dropping dead results.
+    auto simplifyFn = [&](Operation *op) {
+      if (isOpTriviallyDead(op))
+        rewriter.eraseOp(op);
+    };
+    newWarpOp.getBody()->walk(simplifyFn);
+
     // Replace results of the old warpOp by the new, deduplicated results.
     SmallVector<Value> newValues;
     newValues.reserve(warpOp->getNumResults());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index f050bcd246e5ef7..3f95a39100b2b88 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
 
 // -----
 
+func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    "some_use"(%0) : (vector<1xf32>) -> ()
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
+  }
+  return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
+//       CHECK-PROP:   vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
+//       CHECK-PROP:     %[[INNER_READ:.+]] = vector.transfer_read
+//       CHECK-PROP:     "some_use"(%[[INNER_READ]])
+//       CHECK-PROP:     vector.yield %[[INNER_READ]] : vector<1xf32>
+//       CHECK-PROP:   vector.transfer_read
+
+// -----
+
+func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+    %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
+    %max = arith.maximumf %0, %1 : vector<64xf32>
+    vector.yield %max : vector<64xf32>
+  }
+  return %r : vector<1xf32>
+}
+
+//   CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
+// CHECK-PROP-COUNT-2:   vector.transfer_read {{.*}} vector<1xf32>
+//         CHECK-PROP:   arith.maximumf {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
   %c0 = arith.constant 0 : index
   vector.warp_execute_on_lane_0(%laneid)[32] -> () {
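
To make the second bug concrete, here is a before/after sketch of what the new @warp_propagate_dead_user_multi_read test exercises. The lane-local indices %idx0 and %idx1 are hypothetical names for the lane-dependent indices the distribution patterns compute from %laneid:

// Before: both vector<64xf32> reads are only used by %max, which is yielded.
%r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
  %0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
  %1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
  %max = arith.maximumf %0, %1 : vector<64xf32>
  vector.yield %max : vector<64xf32>
}

// After distribution (per the CHECK-PROP lines): everything is lane-local.
// Without the new cleanup in WarpOpDeadResult, a dead arith.maximumf left in
// the spliced warp body still counted as a second user of each transfer_read,
// so the hasOneUse check failed and the reads were never propagated.
%0 = vector.transfer_read %src[%idx0], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
%1 = vector.transfer_read %src[%idx1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
%r = arith.maximumf %0, %1 : vector<1xf32>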

@llvmbot (Member)

llvmbot commented Nov 8, 2023

@llvm/pr-subscribers-mlir-vector


@qedawkins qedawkins force-pushed the transfer_read_distribution_fixes branch from 7f9b0f3 to bfa040d on November 9, 2023 14:08
@qedawkins qedawkins requested a review from antiagainst November 9, 2023 14:13
@qedawkins qedawkins merged commit 7360d5d into llvm:main Nov 9, 2023
qedawkins added a commit to iree-org/llvm-project that referenced this pull request Nov 10, 2023
…m#71625)
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…m#71625)