
Commit d206153

[mlir][vector] Modify constraint and interface for warp reduce on f16 and i8
Quantization methods are crucial and ubiquitous in accelerating machine learning workloads, and most of them use f16 and i8 types. This patch relaxes the type constraints on warp reduce distribution to allow these types. It also changes the interface, moving the initial reduction of data on a single thread into distributedReductionFn; this gives developers the flexibility to control how the initial lane value is obtained, which may differ depending on the input type (e.g., to shuffle through a 32-bit-wide type, an f16 vector must be reduced to a 2xf16 value rather than a single element).

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137691
1 parent dc9846c commit d206153
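
With this change the per-lane vector reaches distributedReductionFn unreduced, so the callback itself decides how to collapse it before exchanging values across the warp. The sketch below shows what such a callback can look like; it follows the test helper updated in this commit (see the TestVectorTransforms.cpp diff further down), with a gpu.shuffle butterfly as one possible exchange strategy. The includes and using-directive are assumed boilerplate rather than part of the patch.

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// A distributedReductionFn callback under the new interface: it receives the
// whole per-lane vector and performs the initial single-thread reduction
// itself before distributing the result across the warp.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           vector::CombiningKind kind, uint32_t size) {
  // First reduce on a single thread to get the per-lane reduction value.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Parallel reduction using butterfly shuffles across the warp.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i, /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}

Because the reduction of the vector now happens inside the callback, a target can just as easily keep the lane value as a small vector (as in the f16 case mentioned above) instead of collapsing it to a scalar.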

File tree: 2 files changed, 10 insertions(+), 10 deletions(-)

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 8 additions & 9 deletions
@@ -1135,12 +1135,13 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
       return rewriter.notifyMatchFailure(
           warpOp, "Reduction vector dimension must match was size.");
-    // Only f32 and i32 element types are supported.
+    // Only f32, i32, f16, i8 element types are supported.
     if (!reductionOp.getType().isF32() &&
-        !reductionOp.getType().isSignlessInteger(32))
+        !reductionOp.getType().isSignlessInteger(32) &&
+        !reductionOp.getType().isF16() && !reductionOp.getType().isInteger(8))
       return rewriter.notifyMatchFailure(
-          warpOp,
-          "Reduction distribution currently only supports 32bits types.");
+          warpOp, "Reduction distribution currently only supports 32bits, f16, "
+                  "and i8 types.");
 
     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
@@ -1157,13 +1158,11 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
 
+    // Obtain data to reduce for a single lane.
     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
-    // First reduce on a single thread.
-    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
-        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
-    // Then distribute across threads.
+    // Distribute and reduce across threads.
     Value fullReduce =
-        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
+        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
     if (reductionOp.getAcc()) {
       fullReduce = vector::makeArithReduction(

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 1 deletion
@@ -686,7 +686,8 @@ static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
 
 static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                            CombiningKind kind, uint32_t size) {
-  Value laneVal = input;
+  // First reduce on a single thread to get per lane reduction value.
+  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
   // Parallel reduction using butterfly shuffles.
   for (uint64_t i = 1; i < size; i <<= 1) {
     Value shuffled = builder
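
As the commit message notes, a callback for f16 may want to stop the per-lane reduction at a vector<2xf16> so the data can travel through a 32-bit shuffle. The helper below is a hypothetical illustration of that packing (it is not part of this patch, and the name shuffleF16Pair is made up): it bitcasts the pair to an i32 around gpu.shuffle and bitcasts the result back.

// Hypothetical helper, not part of this patch: moves a vector<2xf16> value
// across lanes by packing it into a single 32-bit payload for gpu.shuffle.
// Assumes the per-lane vector has already been reduced to two f16 elements.
static Value shuffleF16Pair(Location loc, OpBuilder &builder, Value pair,
                            int32_t offset, int32_t width) {
  // Reinterpret the two f16 values as one i32 so the shuffle sees a 32-bit
  // scalar.
  auto i32VecTy = VectorType::get({1}, builder.getI32Type());
  Value asI32Vec = builder.create<vector::BitCastOp>(loc, i32VecTy, pair);
  Value asI32 = builder.create<vector::ExtractOp>(loc, asI32Vec, /*position=*/0);
  Value shuffled = builder
                       .create<gpu::ShuffleOp>(loc, asI32, offset, width,
                                               gpu::ShuffleMode::XOR)
                       .getShuffleResult();
  // Rebuild the vector<2xf16> from the shuffled 32-bit payload.
  Value asVec = builder.create<vector::BroadcastOp>(loc, i32VecTy, shuffled);
  return builder.create<vector::BitCastOp>(
      loc, VectorType::get({2}, builder.getF16Type()), asVec);
}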
