Commit 303fdde
[mlir] [VectorOps] Rewriting of vector.extract/insert_slices to other vector ops
Summary: Rewrites the extract/insert_slices operations in terms of the
strided_slice/insert_strided_slice ops, with intermediate tuple uses (which
should get optimized away under typical usage). This is done in a separate
"pass" to enable testing this particular rewriting in isolation.

Reviewers: nicolasvasilache, andydavis1, ftynse

Reviewed By: nicolasvasilache

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73295
1 parent e3a7c77 commit 303fdde
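Concretely, on the first case from the new test file below, the rewriting turns

  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
      : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>

into four vector.strided_slice ops (at offsets [0, 0], [0, 2], [2, 0], [2, 2]) feeding a single vector.tuple; once every tuple element is consumed through vector.tuple_get, the folding and dead-code elimination mentioned in the summary remove the tuple ops entirely.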

File tree: 4 files changed (+221, −0 lines)

mlir/include/mlir/Dialect/VectorOps/VectorOps.h

Lines changed: 11 additions & 0 deletions
@@ -43,6 +43,17 @@ void populateVectorToVectorCanonicalizationPatterns(
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of vector slices transformation patterns:
+///   ExtractSlicesOpLowering, InsertSlicesOpLowering
+/// Useful for clients that want to express all vector "slices"
+/// ops in terms of more elementary vector "slice" ops. If all
+/// "produced" tuple values are "consumed" (the most common
+/// use for "slices" ops), this lowering removes all tuple related
+/// operations as well (through DCE and folding). If tuple values
+/// "leak" coming in, however, some tuple related ops will remain.
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
+                                          MLIRContext *context);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
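Wiring these patterns into a client pass is a one-liner through this entry point. A minimal sketch (the pass name MySlicesLoweringPass is hypothetical; the body mirrors the TestVectorSlicesConversion test pass added later in this commit, and the exact headers may vary by MLIR revision):

  #include "mlir/Dialect/VectorOps/VectorOps.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Pass/Pass.h"

  using namespace mlir;

  // Hypothetical client pass that lowers vector.extract_slices and
  // vector.insert_slices into their strided-slice equivalents.
  struct MySlicesLoweringPass : public FunctionPass<MySlicesLoweringPass> {
    void runOnFunction() override {
      OwningRewritePatternList patterns;
      vector::populateVectorSlicesLoweringPatterns(patterns, &getContext());
      applyPatternsGreedily(getFunction(), patterns);
    }
  };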

mlir/lib/Dialect/VectorOps/VectorTransforms.cpp

Lines changed: 132 additions & 0 deletions
@@ -13,6 +13,7 @@
 #include <type_traits>
 
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/Dialect/VectorOps/VectorTransforms.h"
 #include "mlir/Dialect/VectorOps/VectorUtils.h"
@@ -28,6 +29,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/Functional.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 
 #include "llvm/Support/CommandLine.h"
@@ -657,6 +659,131 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
   }
 };
 
+/// Progressive lowering of ExtractSlicesOp to tuple of StridedSliceOp.
+/// One:
+///   %x = vector.extract_slices %0
+/// is replaced by:
+///   %a = vector.strided_slice %0
+///   %b = vector.strided_slice %0
+///   ..
+///   %x = vector.tuple %a, %b, ..
+class ExtractSlicesOpLowering
+    : public OpRewritePattern<vector::ExtractSlicesOp> {
+public:
+  using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getSourceVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> sliceDimCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // For each element in the tuple, generate the proper strided slice.
+    auto basis = computeStrides(sliceDimCounts);
+    TupleType tupleType = op.getResultTupleType();
+    int64_t tupleSize = tupleType.size();
+    SmallVector<Value, 4> tupleValues(tupleSize);
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // De-linearize w.r.t. 'basis'.
+      auto vectorOffsets = delinearize(i, basis);
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r)
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      // Insert in tuple.
+      tupleValues[i] = rewriter.create<vector::StridedSliceOp>(
+          loc, op.vector(), elementOffsets, sliceSizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::TupleOp>(op, tupleType, tupleValues);
+    return matchSuccess();
+  }
+};
+
+/// Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
+/// One:
+///   %x = vector.insert_slices %0
+/// is replaced by:
+///   %r0 = vector.splat 0
+///   %t1 = vector.tuple_get %0, 0
+///   %r1 = vector.insert_strided_slice %r0, %t1
+///   %t2 = vector.tuple_get %0, 1
+///   %r2 = vector.insert_strided_slice %r1, %t2
+///   ..
+///   %x = ..
+class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
+public:
+  using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
+
+  // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
+  PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType vectorType = op.getResultVectorType();
+    int64_t rank = vectorType.getRank();
+    auto shape = vectorType.getShape();
+
+    SmallVector<int64_t, 4> sizes;
+    op.getSizes(sizes);
+    SmallVector<int64_t, 4> strides;
+    op.getStrides(strides); // all-ones at the moment
+
+    // Compute the number of slices in each dimension.
+    SmallVector<int64_t, 4> sliceDimCounts(rank);
+    for (int64_t r = 0; r < rank; ++r)
+      sliceDimCounts[r] = ceilDiv(shape[r], sizes[r]);
+
+    // Prepare result.
+    auto elemType = vectorType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value result = rewriter.create<SplatOp>(loc, vectorType, zero);
+
+    // For each element in the tuple, extract the proper strided slice.
+    auto basis = computeStrides(sliceDimCounts);
+    TupleType tupleType = op.getSourceTupleType();
+    int64_t tupleSize = tupleType.size();
+    SmallVector<Value, 4> tupleValues(tupleSize);
+    for (int64_t i = 0; i < tupleSize; ++i) {
+      // De-linearize w.r.t. 'basis'.
+      auto vectorOffsets = delinearize(i, basis);
+      // Convert from unrolled vector-space offsets to element-space offsets.
+      auto elementOffsets = mlir::functional::zipMap(
+          [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
+      // Compute the size of each slice.
+      SmallVector<int64_t, 4> sliceSizes(rank);
+      for (int64_t r = 0; r < rank; ++r)
+        sliceSizes[r] = std::min(sizes[r], shape[r] - elementOffsets[r]);
+      // Extract from tuple into the result.
+      auto index = rewriter.getI64IntegerAttr(i);
+      auto tupleGet = rewriter.create<vector::TupleGetOp>(
+          loc, tupleType.getType(i), op.getOperand(), index);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, tupleGet, result, elementOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -666,3 +793,8 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
   patterns.insert<SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
       context);
 }
+
+void mlir::vector::populateVectorSlicesLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+}
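Both lowerings share the same index arithmetic: count the slices per dimension with ceilDiv, build a linearization basis with computeStrides, and map each tuple index back to per-dimension offsets with delinearize. A self-contained sketch of that computation (ceilDiv, computeStrides, and delinearize are reimplemented here for illustration; they are assumptions about the helpers' behavior, not MLIR code), run on the 3x3 case from the tests below:

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

  // Suffix products: basis[r] = product of counts[r+1..], so that a linear
  // tuple index can be decomposed back into per-dimension slice coordinates.
  static std::vector<int64_t> computeStrides(const std::vector<int64_t> &counts) {
    std::vector<int64_t> basis(counts.size());
    int64_t running = 1;
    for (int64_t r = static_cast<int64_t>(counts.size()) - 1; r >= 0; --r) {
      basis[r] = running;
      running *= counts[r];
    }
    return basis;
  }

  static std::vector<int64_t> delinearize(int64_t index,
                                          const std::vector<int64_t> &basis) {
    std::vector<int64_t> offsets(basis.size());
    for (size_t r = 0; r < basis.size(); ++r) {
      offsets[r] = index / basis[r];
      index %= basis[r];
    }
    return offsets;
  }

  int main() {
    // vector<3x3xf32> split into 2x2 slices, the case used by the tests below.
    std::vector<int64_t> shape = {3, 3}, sizes = {2, 2};
    std::vector<int64_t> counts = {ceilDiv(shape[0], sizes[0]),
                                   ceilDiv(shape[1], sizes[1])}; // [2, 2]
    std::vector<int64_t> basis = computeStrides(counts);         // [2, 1]
    for (int64_t i = 0; i < counts[0] * counts[1]; ++i) {
      std::vector<int64_t> vecOff = delinearize(i, basis);
      // Element-space offsets and clamped slice sizes, as in the lowerings.
      int64_t off0 = vecOff[0] * sizes[0], off1 = vecOff[1] * sizes[1];
      int64_t sz0 = std::min(sizes[0], shape[0] - off0);
      int64_t sz1 = std::min(sizes[1], shape[1] - off1);
      std::cout << "tuple[" << i << "]: offsets = [" << off0 << ", " << off1
                << "], sizes = [" << sz0 << ", " << sz1 << "]\n";
    }
    // Prints offsets [0,0], [0,2], [2,0], [2,2] with sizes 2x2, 2x1, 1x2,
    // 1x1, matching the strided_slice ops in the CHECK lines below.
    return 0;
  }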
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -test-vector-slices-conversion | FileCheck %s
+
+// CHECK-LABEL: func @extract_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[SS:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: return %[[SS]]
+
+func @extract_slices(%arg0: vector<3x3xf32>) -> vector<2x2xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+      : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.tuple_get %0, 0 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  return %1 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: func @insert_slices(%arg0: vector<2x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<1x2xf32>, %arg3: vector<1x1xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %arg0, %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %arg1, %[[I0]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %arg2, %[[I1]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[I3:.*]] = vector.insert_strided_slice %arg3, %[[I2]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[I3]]
+
+func @insert_slices(%arg0: vector<2x2xf32>,
+                    %arg1: vector<2x1xf32>,
+                    %arg2: vector<1x2xf32>,
+                    %arg3: vector<1x1xf32>) -> vector<3x3xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
+      : vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>
+  %1 = vector.insert_slices %0, [2, 2], [1, 1]
+      : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+  return %1 : vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @extract_insert_slices(%arg0: vector<3x3xf32>)
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
+// CHECK: %[[X2:.*]] = vector.strided_slice %arg0 {offsets = [2, 0], sizes = [1, 2], strides = [1, 1]}
+// CHECK: %[[X3:.*]] = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [1, 1], strides = [1, 1]}
+// CHECK: %[[X4:.*]] = vector.insert_strided_slice %[[X0]], %[[C0]] {offsets = [0, 0], strides = [1, 1]}
+// CHECK: %[[X5:.*]] = vector.insert_strided_slice %[[X1]], %[[X4]] {offsets = [0, 2], strides = [1, 1]}
+// CHECK: %[[X6:.*]] = vector.insert_strided_slice %[[X2]], %[[X5]] {offsets = [2, 0], strides = [1, 1]}
+// CHECK: %[[X7:.*]] = vector.insert_strided_slice %[[X3]], %[[X6]] {offsets = [2, 2], strides = [1, 1]}
+// CHECK: return %[[X7]]
+
+func @extract_insert_slices(%arg0: vector<3x3xf32>) -> vector<3x3xf32> {
+  %0 = vector.extract_slices %arg0, [2, 2], [1, 1]
+      : vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
+  %1 = vector.insert_slices %0, [2, 2], [1, 1]
+      : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>> into vector<3x3xf32>
+  return %1 : vector<3x3xf32>
+}
+
+// CHECK-LABEL: func @extract_slices_tuple_leaks(%arg0: vector<4xf32>)
+// CHECK: %[[X0:.*]] = vector.strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]}
+// CHECK: %[[X1:.*]] = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]}
+// CHECK: %[[X2:.*]] = vector.tuple %[[X0]], %[[X1]]
+// CHECK: return %[[X2]]
+
+func @extract_slices_tuple_leaks(%arg0: vector<4xf32>) -> tuple<vector<2xf32>, vector<2xf32>> {
+  %0 = vector.extract_slices %arg0, [2], [1] : vector<4xf32> into tuple<vector<2xf32>, vector<2xf32>>
+  return %0 : tuple<vector<2xf32>, vector<2xf32>>
+}
+
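Note how the last case exercises the caveat documented in VectorOps.h above: the tuple produced by vector.extract_slices "leaks" out through the return, so the vector.tuple op survives the lowering instead of being folded away.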

mlir/test/lib/Transforms/TestVectorTransforms.cpp

Lines changed: 15 additions & 0 deletions
@@ -18,6 +18,7 @@ using namespace mlir;
 using namespace mlir::vector;
 
 namespace {
+
 #include "TestVectorTransformPatterns.h.inc"
 
 struct TestVectorToVectorConversion
@@ -31,8 +32,22 @@ struct TestVectorToVectorConversion
     applyPatternsGreedily(getFunction(), patterns);
   }
 };
+
+struct TestVectorSlicesConversion
+    : public FunctionPass<TestVectorSlicesConversion> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorSlicesLoweringPatterns(patterns, &getContext());
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
 } // end anonymous namespace
 
 static PassRegistration<TestVectorToVectorConversion>
     pass("test-vector-to-vector-conversion",
         "Test conversion patterns between ops in the vector dialect");
+
+static PassRegistration<TestVectorSlicesConversion> slices_pass(
+    "test-vector-slices-conversion",
+    "Test conversion patterns that lower slices ops in the vector dialect");
