Skip to content

Commit ad100b3

Browse files
[mlir][vector] Fix dominance error in warp vector distribution (llvm#77771)
This commit fixes a test in `vector-warp-distribute.mlir` when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled. ``` within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use %1 = vector.extract %0[9] : f32 from vector<64xf32> ^ within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region) "func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({ ^bb0(%arg0: index): %0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({ %7 = "some_def"() : () -> vector<64xf32> %8 = "arith.constant"() <{value = 9 : index}> : () -> index %9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32 "vector.yield"(%9, %7) : (f32, vector<64xf32>) -> () }) : (index) -> (f32, vector<2xf32>) %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index %2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index %3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32 %4 = "arith.index_cast"(%1) : (index) -> i32 %5 = "arith.constant"() <{value = 32 : i32}> : () -> i32 %6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1) "func.return"(%6#0) : (f32) -> () }) : () -> () LLVM ERROR: IR failed to verify after pattern application ``` The position at which `vector.extractelement` extracts must also be distributed. The fix in `WarpOpExtractElement` is similar to `WarpOpInsertElement`.
1 parent 460ff58 commit ad100b3

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,11 +1321,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13211321
} else {
13221322
distributedVecType = extractSrcType;
13231323
}
1324-
// Yield source vector from warp op.
1324+
// Yield source vector and position (if present) from warp op.
1325+
SmallVector<Value> additionalResults{extractOp.getVector()};
1326+
SmallVector<Type> additionalResultTypes{distributedVecType};
1327+
if (static_cast<bool>(extractOp.getPosition())) {
1328+
additionalResults.push_back(extractOp.getPosition());
1329+
additionalResultTypes.push_back(extractOp.getPosition().getType());
1330+
}
13251331
Location loc = extractOp.getLoc();
13261332
SmallVector<size_t> newRetIndices;
13271333
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1328-
rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
1334+
rewriter, warpOp, additionalResults, additionalResultTypes,
13291335
newRetIndices);
13301336
rewriter.setInsertionPointAfter(newWarpOp);
13311337
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1354,14 +1360,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13541360
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
13551361
// tid of extracting thread: pos / elementsPerLane
13561362
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1357-
loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
1363+
loc, sym0.ceilDiv(elementsPerLane),
1364+
newWarpOp->getResult(newRetIndices[1]));
13581365
// Extract at position: pos % elementsPerLane
13591366
Value pos =
13601367
elementsPerLane == 1
13611368
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
13621369
: rewriter
1363-
.create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1364-
extractOp.getPosition())
1370+
.create<affine::AffineApplyOp>(
1371+
loc, sym0 % elementsPerLane,
1372+
newWarpOp->getResult(newRetIndices[1]))
13651373
.getResult();
13661374
Value extracted =
13671375
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);

0 commit comments

Comments
 (0)