Skip to content

Commit 7360d5d

Browse files
authored
[mlir][vector] Fix cases with multiple yielded transfer_read ops (#71625)
This fixes two bugs: 1) When deciding whether a transfer read could be propagated out of a warp op, the pattern looked only at the first yield operand that was produced by a transfer read. If that transfer read wasn't ready to be distributed, the pattern would not re-check for any other transfer reads that could have been propagated. 2) When dropping dead warp results, we update the warp op signature and splice in the old region. This does not add the ops in the body of the warp op back to the pattern applicator's worklist, and thus those operations won't be DCE'd. This is a problem for patterns like the one for transfer reads, which will still see the dead operation as a user.
1 parent 451bc3e commit 7360d5d

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -801,14 +801,17 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
801801
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
802802
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
803803
PatternRewriter &rewriter) const override {
804-
OpOperand *operand = getWarpResult(
805-
warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
804+
// Try to find a distributable yielded read. Note that this pattern can
805+
// still fail at the end after distribution, in which case this might have
806+
// missed another distributable read.
807+
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
808+
// Don't duplicate transfer_read ops when distributing.
809+
return isa<vector::TransferReadOp>(op) && op->hasOneUse();
810+
});
806811
if (!operand)
807812
return failure();
808813
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
809-
// Don't duplicate transfer_read ops when distributing.
810-
if (!read.getResult().hasOneUse())
811-
return failure();
814+
812815
unsigned operandIndex = operand->getOperandNumber();
813816
Value distributedVal = warpOp.getResult(operandIndex);
814817

@@ -937,6 +940,13 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
937940
// Move the body of the old warpOp to a new warpOp.
938941
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
939942
rewriter, warpOp, newYieldValues, newResultTypes);
943+
944+
// Simplify the new warp op after dropping dead results.
945+
newWarpOp.getBody()->walk([&](Operation *op) {
946+
if (isOpTriviallyDead(op))
947+
rewriter.eraseOp(op);
948+
});
949+
940950
// Replace results of the old warpOp by the new, deduplicated results.
941951
SmallVector<Value> newValues;
942952
newValues.reserve(warpOp->getNumResults());

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,43 @@ func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<409
12561256

12571257
// -----
12581258

1259+
func.func @warp_propagate_multi_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>, vector<1xf32>) {
1260+
%f0 = arith.constant 0.000000e+00 : f32
1261+
%r:2 = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>, vector<1xf32>) {
1262+
%0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
1263+
"some_use"(%0) : (vector<1xf32>) -> ()
1264+
%1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
1265+
vector.yield %0, %1 : vector<1xf32>, vector<1xf32>
1266+
}
1267+
return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
1268+
}
1269+
1270+
// CHECK-PROP-LABEL: func.func @warp_propagate_multi_transfer_read
1271+
// CHECK-PROP: vector.warp_execute_on_lane_0{{.*}} -> (vector<1xf32>)
1272+
// CHECK-PROP: %[[INNER_READ:.+]] = vector.transfer_read
1273+
// CHECK-PROP: "some_use"(%[[INNER_READ]])
1274+
// CHECK-PROP: vector.yield %[[INNER_READ]] : vector<1xf32>
1275+
// CHECK-PROP: vector.transfer_read
1276+
1277+
// -----
1278+
1279+
func.func @warp_propagate_dead_user_multi_read(%laneid: index, %src: memref<4096xf32>, %index: index, %index1: index) -> (vector<1xf32>) {
1280+
%f0 = arith.constant 0.000000e+00 : f32
1281+
%r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
1282+
%0 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
1283+
%1 = vector.transfer_read %src[%index1], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<64xf32>
1284+
%max = arith.maximumf %0, %1 : vector<64xf32>
1285+
vector.yield %max : vector<64xf32>
1286+
}
1287+
return %r : vector<1xf32>
1288+
}
1289+
1290+
// CHECK-PROP-LABEL: func.func @warp_propagate_dead_user_multi_read
1291+
// CHECK-PROP-COUNT-2: vector.transfer_read {{.*}} vector<1xf32>
1292+
// CHECK-PROP: arith.maximumf {{.*}} : vector<1xf32>
1293+
1294+
// -----
1295+
12591296
func.func @warp_propagate_masked_write(%laneid: index, %dest: memref<4096xf32>) {
12601297
%c0 = arith.constant 0 : index
12611298
vector.warp_execute_on_lane_0(%laneid)[32] -> () {

0 commit comments

Comments
 (0)