Skip to content

Commit fcd6163

Browse files
committed
[mlir][vector] Fix cases with multiple yielded transfer_read ops
This fixes two bugs: 1) When deciding whether a transfer read could be propagated out of a warp op, the pattern only looked at the first yield operand that was produced by a transfer read. If that transfer read wasn't ready to be distributed, the pattern failed without checking whether any other yielded transfer read could have been propagated. 2) When dropping dead warp results, we update the warp op signature and splice in the old region. This does not add the ops in the body of the warp op back to the pattern applicator's worklist, so those operations are not DCE'd. This is a problem for patterns, like the one for transfer reads, that will still see the dead operation as a user.
1 parent 220abb0 commit fcd6163

File tree

2 files changed

+70
-7
lines changed

2 files changed

+70
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -801,13 +801,31 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
801801
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
802802
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
803803
PatternRewriter &rewriter) const override {
804-
OpOperand *operand = getWarpResult(
805-
warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
806-
if (!operand)
807-
return failure();
808-
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
809-
// Don't duplicate transfer_read ops when distributing.
810-
if (!read.getResult().hasOneUse())
804+
// Try to find a distributable yielded read. Note that this pattern can
805+
// still fail at the end after distribution, in which case this might have
806+
// missed another distributable read.
807+
vector::TransferReadOp read;
808+
auto yield = cast<vector::YieldOp>(
809+
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
810+
OpOperand *operand;
811+
for (OpOperand &yieldOperand : yield->getOpOperands()) {
812+
Value yieldValues = yieldOperand.get();
813+
Operation *definedOp = yieldValues.getDefiningOp();
814+
if (!definedOp)
815+
continue;
816+
auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
817+
if (!maybeRead)
818+
continue;
819+
if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
820+
continue;
821+
// Don't duplicate transfer_read ops when distributing.
822+
if (!maybeRead.getResult().hasOneUse())
823+
continue;
824+
read = maybeRead;
825+
operand = &yieldOperand;
826+
break;
827+
}
828+
if (!read)
811829
return failure();
812830
unsigned operandIndex = operand->getOperandNumber();
813831
Value distributedVal = warpOp.getResult(operandIndex);
@@ -913,6 +931,14 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
913931
// Move the body of the old warpOp to a new warpOp.
914932
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
915933
rewriter, warpOp, newYieldValues, newResultTypes);
934+
935+
// Simplify the new warp op after dropping dead results.
936+
auto simplifyFn = [&](Operation *op) {
937+
if (isOpTriviallyDead(op))
938+
rewriter.eraseOp(op);
939+
};
940+
newWarpOp.getBody()->walk(simplifyFn);
941+
916942
// Replace results of the old warpOp by the new, deduplicated results.
917943
SmallVector<Value> newValues;
918944
newValues.reserve(warpOp->getNumResults());

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
12561256

12571257
// -----
12581258

1259+
func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
1260+
%f0 = arith.constant 0.000000e+00 : f32
1261+
%r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
1262+
%0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
1263+
"some_use"(%0) : (vector<1xf32>) -> ()
1264+
%1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
1265+
vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
1266+
}
1267+
return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
1268+
}
1269+
1270+
// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
1271+
// CHECK-PROP: vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
1272+
// CHECK-PROP: %[[INNER_READ:.+]] = vector.transfer_read
1273+
// CHECK-PROP: "some_use"(%[[INNER_READ]])
1274+
// CHECK-PROP: vector.yield %[[INNER_READ]] : vector<1xf32>
1275+
// CHECK-PROP: vector.transfer_read
1276+
1277+
// -----
1278+
1279+
func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
1280+
%f0 = arith.constant 0.000000e+00 : f32
1281+
%r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
1282+
%0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
1283+
%1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
1284+
%max = arith.maximumf %0, %1 : vector<64xf32>
1285+
vector.yield %max : vector<64xf32>
1286+
}
1287+
return %r : vector<1xf32>
1288+
}
1289+
1290+
// CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
1291+
// CHECK-PROP-COUNT-2: vector.transfer_read {{.*}} vector<1xf32>
1292+
// CHECK-PROP: arith.maximumf {{.*}} : vector<1xf32>
1293+
1294+
// -----
1295+
12591296
func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
12601297
%c0 = arith.constant 0 : index
12611298
vector.warp_execute_on_lane_0(%laneid)[32] -> () {

0 commit comments

Comments
 (0)