Skip to content

Commit de5022c

Browse files
[mlir][vector] Implement unrolling of ReductionOp
Differential Revision: https://reviews.llvm.org/D121597
1 parent 20f7f73 commit de5022c

File tree

5 files changed

+81
-2
lines changed

5 files changed

+81
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,9 @@ def Vector_ContractionOp :
266266
def Vector_ReductionOp :
267267
Vector_Op<"reduction", [NoSideEffect,
268268
PredOpTrait<"source operand and result have same element type",
269-
TCresVTEtIsSameAsOpBase<0, 0>>]>,
269+
TCresVTEtIsSameAsOpBase<0, 0>>,
270+
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
271+
["getShapeForUnroll"]>]>,
270272
Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVector:$vector,
271273
Optional<AnyType>:$acc)>,
272274
Results<(outs AnyType:$dest)> {

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,10 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
484484
return nullptr;
485485
}
486486

487+
Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
488+
return llvm::to_vector<4>(getVectorType().getShape());
489+
}
490+
487491
//===----------------------------------------------------------------------===//
488492
// ContractionOp
489493
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,13 +631,60 @@ struct TransferWriteInsertPattern
631631
}
632632
};
633633

634+
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
635+
UnrollReductionPattern(MLIRContext *context,
636+
const vector::UnrollVectorOptions &options)
637+
: OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
638+
options(options) {}
639+
640+
LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
641+
PatternRewriter &rewriter) const override {
642+
Optional<SmallVector<int64_t, 4>> targetShape =
643+
getTargetShape(options, reductionOp);
644+
if (!targetShape)
645+
return failure();
646+
SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
647+
int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];
648+
649+
// Create unrolled vector reduction.
650+
Location loc = reductionOp.getLoc();
651+
Value accumulator = nullptr;
652+
for (int64_t i = 0; i < ratio; ++i) {
653+
SmallVector<int64_t> offsets =
654+
getVectorOffset(originalSize, *targetShape, i);
655+
SmallVector<int64_t> strides(offsets.size(), 1);
656+
Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
657+
loc, reductionOp.vector(), offsets, *targetShape, strides);
658+
Operation *newOp = cloneOpWithOperandsAndTypes(
659+
rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
660+
Value result = newOp->getResult(0);
661+
662+
if (!accumulator) {
663+
// This is the first reduction.
664+
accumulator = result;
665+
} else {
666+
// On subsequent reduction, combine with the accumulator.
667+
accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(),
668+
accumulator, result);
669+
}
670+
}
671+
672+
rewriter.replaceOp(reductionOp, accumulator);
673+
return success();
674+
}
675+
676+
private:
677+
const vector::UnrollVectorOptions options;
678+
};
679+
634680
} // namespace
635681

636682
void mlir::vector::populateVectorUnrollPatterns(
637683
RewritePatternSet &patterns, const UnrollVectorOptions &options) {
638684
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
639685
UnrollContractionPattern, UnrollElementwisePattern,
640-
UnrollMultiReductionPattern>(patterns.getContext(), options);
686+
UnrollReductionPattern, UnrollMultiReductionPattern>(
687+
patterns.getContext(), options);
641688
}
642689

643690
void mlir::vector::populatePropagateVectorDistributionPatterns(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,23 @@ func @vector_multi_reduction(%v : vector<4x6xf32>) -> vector<4xf32> {
106106
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[A1]], %[[V0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
107107
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[A3]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
108108
// CHECK: return %[[V2]] : vector<4xf32>
109+
110+
// CHECK-LABEL: func @vector_reduction(
111+
// CHECK-SAME: %[[v:.*]]: vector<8xf32>
112+
// CHECK: %[[s0:.*]] = vector.extract_strided_slice %[[v]] {offsets = [0], sizes = [2]
113+
// CHECK: %[[r0:.*]] = vector.reduction <add>, %[[s0]]
114+
// CHECK: %[[s1:.*]] = vector.extract_strided_slice %[[v]] {offsets = [2], sizes = [2]
115+
// CHECK: %[[r1:.*]] = vector.reduction <add>, %[[s1]]
116+
// CHECK: %[[add1:.*]] = arith.addf %[[r0]], %[[r1]]
117+
// CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[v]] {offsets = [4], sizes = [2]
118+
// CHECK: %[[r2:.*]] = vector.reduction <add>, %[[s2]]
119+
// CHECK: %[[add2:.*]] = arith.addf %[[add1]], %[[r2]]
120+
// CHECK: %[[s3:.*]] = vector.extract_strided_slice %[[v]] {offsets = [6], sizes = [2]
121+
// CHECK: %[[r3:.*]] = vector.reduction <add>, %[[s3]]
122+
// CHECK: %[[add3:.*]] = arith.addf %[[add2]], %[[r3]]
123+
// CHECK: return %[[add3]]
124+
func @vector_reduction(%v : vector<8xf32>) -> f32 {
125+
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
126+
return %0 : f32
127+
}
128+

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ struct TestVectorUnrollingPatterns
268268
return success(isa<arith::AddFOp, vector::FMAOp,
269269
vector::MultiDimReductionOp>(op));
270270
}));
271+
populateVectorUnrollPatterns(
272+
patterns, UnrollVectorOptions()
273+
.setNativeShape(ArrayRef<int64_t>{2})
274+
.setFilterConstraint([](Operation *op) {
275+
return success(isa<vector::ReductionOp>(op));
276+
}));
271277

272278
if (unrollBasedOnType) {
273279
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =

0 commit comments

Comments
 (0)