@@ -1321,11 +1321,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1321
1321
} else {
1322
1322
distributedVecType = extractSrcType;
1323
1323
}
1324
- // Yield source vector from warp op.
1324
+ // Yield source vector and position (if present) from warp op.
1325
+ SmallVector<Value> additionalResults{extractOp.getVector ()};
1326
+ SmallVector<Type> additionalResultTypes{distributedVecType};
1327
+ if (static_cast <bool >(extractOp.getPosition ())) {
1328
+ additionalResults.push_back (extractOp.getPosition ());
1329
+ additionalResultTypes.push_back (extractOp.getPosition ().getType ());
1330
+ }
1325
1331
Location loc = extractOp.getLoc ();
1326
1332
SmallVector<size_t > newRetIndices;
1327
1333
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1328
- rewriter, warpOp, {extractOp. getVector ()}, {distributedVecType} ,
1334
+ rewriter, warpOp, additionalResults, additionalResultTypes ,
1329
1335
newRetIndices);
1330
1336
rewriter.setInsertionPointAfter (newWarpOp);
1331
1337
Value distributedVec = newWarpOp->getResult (newRetIndices[0 ]);
@@ -1354,14 +1360,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1354
1360
AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
1355
1361
// tid of extracting thread: pos / elementsPerLane
1356
1362
Value broadcastFromTid = rewriter.create <affine::AffineApplyOp>(
1357
- loc, sym0.ceilDiv (elementsPerLane), extractOp.getPosition ());
1363
+ loc, sym0.ceilDiv (elementsPerLane),
1364
+ newWarpOp->getResult (newRetIndices[1 ]));
1358
1365
// Extract at position: pos % elementsPerLane
1359
1366
Value pos =
1360
1367
elementsPerLane == 1
1361
1368
? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1362
1369
: rewriter
1363
- .create <affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1364
- extractOp.getPosition ())
1370
+ .create <affine::AffineApplyOp>(
1371
+ loc, sym0 % elementsPerLane,
1372
+ newWarpOp->getResult (newRetIndices[1 ]))
1365
1373
.getResult ();
1366
1374
Value extracted =
1367
1375
rewriter.create <vector::ExtractElementOp>(loc, distributedVec, pos);
0 commit comments