Skip to content

Commit bfa040d

Browse files
committed
Cleanup logic getting transfer read warp results
1 parent fcd6163 commit bfa040d

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -804,29 +804,14 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
804804
// Try to find a distributable yielded read. Note that this pattern can
805805
// still fail at the end after distribution, in which case this might have
806806
// missed another distributable read.
807-
vector::TransferReadOp read;
808-
auto yield = cast<vector::YieldOp>(
809-
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
810-
OpOperand *operand;
811-
for (OpOperand &yieldOperand : yield->getOpOperands()) {
812-
Value yieldValues = yieldOperand.get();
813-
Operation *definedOp = yieldValues.getDefiningOp();
814-
if (!definedOp)
815-
continue;
816-
auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
817-
if (!maybeRead)
818-
continue;
819-
if (warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
820-
continue;
807+
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
821808
// Don't duplicate transfer_read ops when distributing.
822-
if (!maybeRead.getResult().hasOneUse())
823-
continue;
824-
read = maybeRead;
825-
operand = &yieldOperand;
826-
break;
827-
}
828-
if (!read)
809+
return isa<vector::TransferReadOp>(op) && op->hasOneUse();
810+
});
811+
if (!operand)
829812
return failure();
813+
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
814+
830815
unsigned operandIndex = operand->getOperandNumber();
831816
Value distributedVal = warpOp.getResult(operandIndex);
832817

@@ -933,11 +918,10 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
933918
rewriter, warpOp, newYieldValues, newResultTypes);
934919

935920
// Simplify the new warp op after dropping dead results.
936-
auto simplifyFn = [&](Operation *op) {
921+
newWarpOp.getBody()->walk([&](Operation *op) {
937922
if (isOpTriviallyDead(op))
938923
rewriter.eraseOp(op);
939-
};
940-
newWarpOp.getBody()->walk(simplifyFn);
924+
});
941925

942926
// Replace results of the old warpOp by the new, deduplicated results.
943927
SmallVector<Value> newValues;

0 commit comments

Comments
 (0)