@@ -804,29 +804,14 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
804
804
// Try to find a distributable yielded read. Note that this pattern can
805
805
// still fail at the end after distribution, in which case this might have
806
806
// missed another distributable read.
807
- vector::TransferReadOp read;
808
- auto yield = cast<vector::YieldOp>(
809
- warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
810
- OpOperand *operand;
811
- for (OpOperand &yieldOperand : yield->getOpOperands ()) {
812
- Value yieldValues = yieldOperand.get ();
813
- Operation *definedOp = yieldValues.getDefiningOp ();
814
- if (!definedOp)
815
- continue ;
816
- auto maybeRead = dyn_cast<vector::TransferReadOp>(definedOp);
817
- if (!maybeRead)
818
- continue ;
819
- if (warpOp.getResult (yieldOperand.getOperandNumber ()).use_empty ())
820
- continue ;
807
+ OpOperand *operand = getWarpResult (warpOp, [](Operation *op) {
821
808
// Don't duplicate transfer_read ops when distributing.
822
- if (!maybeRead.getResult ().hasOneUse ())
823
- continue ;
824
- read = maybeRead;
825
- operand = &yieldOperand;
826
- break ;
827
- }
828
- if (!read)
809
+ return isa<vector::TransferReadOp>(op) && op->hasOneUse ();
810
+ });
811
+ if (!operand)
829
812
return failure ();
813
+ auto read = operand->get ().getDefiningOp <vector::TransferReadOp>();
814
+
830
815
unsigned operandIndex = operand->getOperandNumber ();
831
816
Value distributedVal = warpOp.getResult (operandIndex);
832
817
@@ -933,11 +918,10 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
933
918
rewriter, warpOp, newYieldValues, newResultTypes);
934
919
935
920
// Simplify the new warp op after dropping dead results.
936
- auto simplifyFn = [&](Operation *op) {
921
+ newWarpOp. getBody ()-> walk ( [&](Operation *op) {
937
922
if (isOpTriviallyDead (op))
938
923
rewriter.eraseOp (op);
939
- };
940
- newWarpOp.getBody ()->walk (simplifyFn);
924
+ });
941
925
942
926
// Replace results of the old warpOp by the new, deduplicated results.
943
927
SmallVector<Value> newValues;
0 commit comments